Skip to content

Commit 3fdb077

Browse files
committed
Updated example for SQLAlchemy 2 [skip ci]
1 parent 141ca8f commit 3fdb077

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

examples/lightfm_recs.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from lightfm import LightFM
22
from lightfm.datasets import fetch_movielens
33
from pgvector.sqlalchemy import Vector
4-
from sqlalchemy import create_engine, text, Column, Float, Integer, String
5-
from sqlalchemy.orm import declarative_base, Session
4+
from sqlalchemy import create_engine, select, text, Float, Integer, String
5+
from sqlalchemy.orm import declarative_base, mapped_column, Session
66

7-
engine = create_engine('postgresql+psycopg2://localhost/pgvector_example', future=True)
7+
engine = create_engine('postgresql+psycopg2://localhost/pgvector_example')
88
with engine.connect() as conn:
99
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
1010
conn.commit()
@@ -15,17 +15,17 @@
1515
class User(Base):
1616
__tablename__ = 'user'
1717

18-
id = Column(Integer, primary_key=True)
19-
factors = Column(Vector(20))
18+
id = mapped_column(Integer, primary_key=True)
19+
factors = mapped_column(Vector(20))
2020

2121

2222
class Item(Base):
2323
__tablename__ = 'item'
2424

25-
id = Column(Integer, primary_key=True)
26-
title = Column(String)
27-
factors = Column(Vector(20))
28-
bias = Column(Float)
25+
id = mapped_column(Integer, primary_key=True)
26+
title = mapped_column(String)
27+
factors = mapped_column(Vector(20))
28+
bias = mapped_column(Float)
2929

3030

3131
Base.metadata.drop_all(engine)
@@ -46,11 +46,11 @@ class Item(Base):
4646
session.bulk_insert_mappings(Item, items)
4747
session.commit()
4848

49-
user = session.query(User).get(1)
49+
user = session.get(User, 1)
5050
# subtract item bias for negative inner product
51-
items = session.query(Item).order_by(Item.factors.max_inner_product(user.factors) - Item.bias).limit(5).all()
51+
items = session.scalars(select(Item).order_by(Item.factors.max_inner_product(user.factors) - Item.bias).limit(5))
5252
print('user-based recs:', [item.title for item in items])
5353

54-
item = session.query(Item).filter(Item.title == 'Star Wars (1977)').first()
55-
items = session.query(Item).filter(Item.id != item.id).order_by(Item.factors.cosine_distance(item.factors)).limit(5).all()
54+
item = session.scalars(select(Item).filter(Item.title == 'Star Wars (1977)')).first()
55+
items = session.scalars(select(Item).filter(Item.id != item.id).order_by(Item.factors.cosine_distance(item.factors)).limit(5))
5656
print('item-based recs:', [item.title for item in items])

0 commit comments

Comments
 (0)