Skip to content

Commit bdd9429

Browse files
committed
Added LightFM example [skip ci]
1 parent 9fad8fc commit bdd9429

2 files changed

Lines changed: 54 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Or check out some examples:
2626

2727
- [Implicit feedback recommendations](examples/implicit_recs.py) with Implicit
2828
- [Explicit feedback recommendations](examples/surprise_recs.py) with Surprise
29+
- [Recommendations](examples/lightfm_recs.py) with LightFM
2930

3031
## Django
3132

examples/lightfm_recs.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from lightfm import LightFM
2+
from lightfm.datasets import fetch_movielens
3+
from pgvector.sqlalchemy import Vector
4+
from sqlalchemy import create_engine, Column, Float, Integer, String
5+
from sqlalchemy.orm import declarative_base, Session
6+
7+
engine = create_engine("postgresql+psycopg2://localhost/pgvector_test", future=True)
8+
9+
Base = declarative_base()
10+
11+
12+
class User(Base):
13+
__tablename__ = 'user'
14+
15+
id = Column(Integer, primary_key=True)
16+
factors = Column(Vector(20))
17+
18+
19+
class Item(Base):
20+
__tablename__ = 'item'
21+
22+
id = Column(Integer, primary_key=True)
23+
title = Column(String)
24+
factors = Column(Vector(20))
25+
bias = Column(Float)
26+
27+
28+
Base.metadata.drop_all(engine)
29+
Base.metadata.create_all(engine)
30+
31+
data = fetch_movielens(min_rating=5.0)
32+
model = LightFM(loss='warp', no_components=20)
33+
model.fit(data['train'], epochs=30)
34+
35+
user_biases, user_factors = model.get_user_representations()
36+
item_biases, item_factors = model.get_item_representations()
37+
38+
users = [dict(id=i, factors=factors) for i, factors in enumerate(user_factors)]
39+
items = [dict(id=i, title=data['item_labels'][i], factors=factors, bias=item_biases[i].item()) for i, factors in enumerate(item_factors)]
40+
41+
session = Session(engine)
42+
session.bulk_insert_mappings(User, users)
43+
session.bulk_insert_mappings(Item, items)
44+
session.commit()
45+
46+
user = session.query(User).get(1)
47+
# subtract item bias for negative inner product
48+
items = session.query(Item).order_by(Item.factors.max_inner_product(user.factors) - Item.bias).limit(5).all()
49+
print('user-based recs:', [item.title for item in items])
50+
51+
item = session.query(Item).filter(Item.title == 'Star Wars (1977)').first()
52+
items = session.query(Item).filter(Item.id != item.id).order_by(Item.factors.cosine_distance(item.factors)).limit(5).all()
53+
print('item-based recs:', [item.title for item in items])

0 commit comments

Comments
 (0)