|
| 1 | +from lightfm import LightFM |
| 2 | +from lightfm.datasets import fetch_movielens |
| 3 | +from pgvector.sqlalchemy import Vector |
| 4 | +from sqlalchemy import create_engine, Column, Float, Integer, String |
| 5 | +from sqlalchemy.orm import declarative_base, Session |
| 6 | + |
| 7 | +engine = create_engine("postgresql+psycopg2://localhost/pgvector_test", future=True) |
| 8 | + |
| 9 | +Base = declarative_base() |
| 10 | + |
| 11 | + |
| 12 | +class User(Base): |
| 13 | + __tablename__ = 'user' |
| 14 | + |
| 15 | + id = Column(Integer, primary_key=True) |
| 16 | + factors = Column(Vector(20)) |
| 17 | + |
| 18 | + |
| 19 | +class Item(Base): |
| 20 | + __tablename__ = 'item' |
| 21 | + |
| 22 | + id = Column(Integer, primary_key=True) |
| 23 | + title = Column(String) |
| 24 | + factors = Column(Vector(20)) |
| 25 | + bias = Column(Float) |
| 26 | + |
| 27 | + |
| 28 | +Base.metadata.drop_all(engine) |
| 29 | +Base.metadata.create_all(engine) |
| 30 | + |
| 31 | +data = fetch_movielens(min_rating=5.0) |
| 32 | +model = LightFM(loss='warp', no_components=20) |
| 33 | +model.fit(data['train'], epochs=30) |
| 34 | + |
| 35 | +user_biases, user_factors = model.get_user_representations() |
| 36 | +item_biases, item_factors = model.get_item_representations() |
| 37 | + |
| 38 | +users = [dict(id=i, factors=factors) for i, factors in enumerate(user_factors)] |
| 39 | +items = [dict(id=i, title=data['item_labels'][i], factors=factors, bias=item_biases[i].item()) for i, factors in enumerate(item_factors)] |
| 40 | + |
| 41 | +session = Session(engine) |
| 42 | +session.bulk_insert_mappings(User, users) |
| 43 | +session.bulk_insert_mappings(Item, items) |
| 44 | +session.commit() |
| 45 | + |
| 46 | +user = session.query(User).get(1) |
| 47 | +# subtract item bias for negative inner product |
| 48 | +items = session.query(Item).order_by(Item.factors.max_inner_product(user.factors) - Item.bias).limit(5).all() |
| 49 | +print('user-based recs:', [item.title for item in items]) |
| 50 | + |
| 51 | +item = session.query(Item).filter(Item.title == 'Star Wars (1977)').first() |
| 52 | +items = session.query(Item).filter(Item.id != item.id).order_by(Item.factors.cosine_distance(item.factors)).limit(5).all() |
| 53 | +print('item-based recs:', [item.title for item in items]) |
0 commit comments