
Commit e8c3882

uploading rag demo

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

1 parent 443c130

File tree

9 files changed: +415 −0 lines

module_4_rag/Dockerfile

Lines changed: 37 additions & 0 deletions

```dockerfile
FROM python:3.9

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

# Set work directory
WORKDIR /code

# Install dependencies
ENV LIBMEMCACHED=/opt/local
RUN apt-get update && apt-get install -y \
    libmemcached11 \
    libmemcachedutil2 \
    libmemcached-dev \
    libz-dev \
    curl \
    gettext

ENV PYTHONHASHSEED=random \
    PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PIP_DEFAULT_TIMEOUT=100 \
    # Poetry's configuration: \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    POETRY_CACHE_DIR='/var/cache/pypoetry' \
    POETRY_HOME='/usr/local' \
    POETRY_VERSION=1.4.1

RUN curl -sSL https://install.python-poetry.org | python3 - --version $POETRY_VERSION

COPY pyproject.toml poetry.lock /code/
RUN poetry install --no-interaction --no-ansi --no-root

COPY . /code/
```

module_4_rag/README.md

Lines changed: 64 additions & 0 deletions
This is a demo showing how you can use Feast for Retrieval Augmented Generation (RAG).

## Installation via PyEnv and Poetry

This demo assumes you have Pyenv (2.3.10), Poetry (1.4.1), and Python 3.9 installed on your machine.

```bash
pyenv local 3.9
poetry shell
poetry install
```

## Setting up the data and Feast

To fetch the data, simply run

```bash
python pull_states.py
```

which will output a file called `city_wikipedia_summaries.csv`.

Then run

```bash
python batch_score_documents.py
```

# Overview

The goal is to define an architecture that supports the following:

```mermaid
flowchart TD;
    A[Pull Data] --> B[Batch Score Embeddings];
    B[Batch Score Embeddings] --> C[Materialize Online];
    C[Materialize Online] --> D[Retrieval Augmented Generation];
    D[Retrieval Augmented Generation] --> E[Store User Interaction];
    E[Store User Interaction] --> F[Update Training Labels];
    F[Update Training Labels] --> H[Fine Tuning];
    H[Fine Tuning] -. Backpropagate .-> B[Batch Score Embeddings];
```
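The Retrieval Augmented Generation step in the flow above boils down to a nearest-neighbor lookup over the batch-scored embeddings. A minimal sketch with toy vectors (the real demo uses 384-dimensional `all-MiniLM-L6-v2` embeddings; `retrieve_top_k` is a hypothetical helper, not part of this commit):

```python
import numpy as np

def retrieve_top_k(query_emb, doc_embs, k=2):
    # The embeddings are L2-normalized, so cosine similarity
    # reduces to a plain dot product.
    scores = doc_embs @ query_emb
    return np.argsort(scores)[::-1][:k]

# Toy example: three normalized 4-dim "document" embeddings.
docs = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.6, 0.8, 0.0, 0.0],
])
query = np.array([0.8, 0.6, 0.0, 0.0])
top = retrieve_top_k(query, docs, k=2)
print(top)  # indices of the two closest summaries
```

The retrieved summaries are then passed to the LLM as context for generation.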
A simple example of the user experience:

```
Q: Can you tell me about Chicago?
A: Here's some wikipedia facts about Chicago...
```

# Limitations

A common issue with RAG and LLMs is hallucination. There are two common approaches to mitigating it:

1. Prompt engineering
   - This approach is the most obvious but is susceptible to prompt injection.

2. Building a classifier to return the "I don't know" response
   - This approach is less obvious and requires another model, more training data, and fine-tuning.

We can, in fact, use both approaches together to further reduce the likelihood of prompt injection.

This demo will demonstrate both.
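The second approach can be sketched as a gate in front of generation. A minimal sketch using a hypothetical confidence threshold in place of a trained classifier (`answer_or_abstain` and the `0.5` threshold are illustrative, not part of this commit; a real implementation would train on labeled retrievals):

```python
def answer_or_abstain(retrieval_score, threshold=0.5):
    # Gate the generation step: if the best retrieved document is
    # not similar enough to the question, return the "I don't know"
    # response instead of passing weak context to the LLM.
    if retrieval_score < threshold:
        return "I don't know."
    return "generate answer from retrieved context"

print(answer_or_abstain(0.2))  # abstains
print(answer_or_abstain(0.9))  # proceeds to generation
```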

module_4_rag/app.py

Lines changed: 99 additions & 0 deletions
```python
from flask import (
    Flask,
    jsonify,
    request,
    render_template,
)
from flasgger import Swagger
from datetime import datetime
from get_features import (
    get_onboarding_features,
    get_onboarding_score,
    get_daily_features,
    get_daily_score,
)
from ml import make_risk_decision

app = Flask(__name__)
swagger = Swagger(app)


@app.route("/")
def onboarding_page():
    return render_template("index.html")


@app.route("/home")
def home_page():
    return render_template("home.html")


@app.route("/onboarding-risk-features/", methods=["POST"])
def onboarding_features():
    """Example endpoint returning features by id
    This is using docstrings for specifications.
    ---
    parameters:
      - name: state
        type: string
        in: query
        required: true
        default: NJ

      - name: ssn
        type: string
        in: query
        required: true
        default: 123-45-6789

      - name: dl
        type: string
        in: query
        required: true
        default: some-dl-number

      - name: dob
        type: string
        in: query
        required: true
        default: 12-23-2000
    responses:
      200:
        description: A JSON of features
        schema:
          id: OnboardingFeatures
          properties:
            is_gt_18_years_old:
              type: array
              items:
                schema:
                  id: value
                  type: number
            is_valid_state:
              type: array
              items:
                schema:
                  id: value
                  type: number
            is_previously_seen_ssn:
              type: array
              items:
                schema:
                  id: value
                  type: number
            is_previously_seen_dl:
              type: array
              items:
                schema:
                  id: value
                  type: number
    """
    r = request.args
    feature_vector = get_onboarding_features(
        r.get("state"), r.get("ssn"), r.get("dl"), r.get("dob")
    )
    return jsonify(feature_vector)


if __name__ == "__main__":
    app.run(debug=True)
```
module_4_rag/batch_score_documents.py

Lines changed: 45 additions & 0 deletions

```python
import os

import pandas as pd
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

INPUT_FILENAME = "city_wikipedia_summaries.csv"
EXPORT_FILENAME = "city_wikipedia_summaries_with_embeddings.csv"
TOKENIZER = 'sentence-transformers/all-MiniLM-L6-v2'
MODEL = 'sentence-transformers/all-MiniLM-L6-v2'


def mean_pooling(model_output, attention_mask):
    # First element of model_output contains all token embeddings
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def run_model(sentences, tokenizer, model):
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings


def score_data() -> None:
    if EXPORT_FILENAME not in os.listdir():
        print("scored data not found...generating embeddings...")
        df = pd.read_csv(INPUT_FILENAME)
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
        model = AutoModel.from_pretrained(MODEL)
        embeddings = run_model(df['Wiki Summary'].tolist(), tokenizer, model)
        print(embeddings)
        print('shape = ', df.shape)
        df['Embeddings'] = list(embeddings.detach().cpu().numpy())
        print("embeddings generated...")
        print(df.head())
        df.to_csv(EXPORT_FILENAME, index=False)
        print("...data exported. job complete")
    else:
        print("scored data found...skipping generating embeddings.")


if __name__ == '__main__':
    score_data()
```
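The `mean_pooling` step above averages token embeddings while ignoring padding. The same arithmetic can be illustrated with plain NumPy on toy shapes (hypothetical values; the real code operates on transformer outputs):

```python
import numpy as np

# 1 sentence, 3 tokens, 2-dim token embeddings; last token is padding.
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
attention_mask = np.array([[1, 1, 0]])  # padding token masked out

mask = attention_mask[..., None]           # shape (1, 3, 1), broadcasts over dims
summed = (token_embeddings * mask).sum(1)  # sum real tokens only
counts = np.clip(mask.sum(1), 1e-9, None)  # avoid division by zero
sentence_embedding = summed / counts       # average of the unmasked tokens
print(sentence_embedding)  # [[2.0, 3.0]]
```

The padding token's [9.0, 9.0] contributes nothing, so the result is the mean of the two real tokens.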

module_4_rag/docker-compose.yml

Lines changed: 14 additions & 0 deletions
```yaml
version: '3.9'

services:
  web:
    env_file:
      - .env
    build: .
    command:
      - /bin/bash
      - -c
      - python3 /code/run.py

    volumes:
      - .:/code
```
Lines changed: 40 additions & 0 deletions

```python
import csv
import random

topics = ["science", "history", "technology", "mathematics", "geography", "literature", "sports", "art", "music", "cinema"]

# Define a pattern for generating questions
question_patterns = [
    "What are the key principles of {}?",
    "Who are the most influential figures in {}?",
    "How has {} evolved over the years?",
    "What are some common misconceptions about {}?",
    "Can you explain the theory of {}?",
    "What role does {} play in modern society?",
    "How does {} affect our daily lives?",
    "What are the future prospects of {}?",
    "What are the major challenges in {} today?",
    "How can one get started with {}?"
]

# Generate a list of 50 random questions
questions = []
for _ in range(50):
    topic = random.choice(topics)
    pattern = random.choice(question_patterns)
    question = pattern.format(topic)
    questions.append([question])


def main():
    # Define the file path
    file_path = './random_questions.csv'

    # Write the questions to a CSV file
    with open(file_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Question"])  # Writing header
        writer.writerows(questions)


if __name__ == "__main__":
    main()
```

module_4_rag/pull_states.py

Lines changed: 85 additions & 0 deletions
```python
import os
from typing import Dict, List

import wikipedia as wiki
import pandas as pd

EXPORT_FILENAME = "city_wikipedia_summaries.csv"
CITIES = [
    "New York, New York",
    "Los Angeles, California",
    "Chicago, Illinois",
    "Houston, Texas",
    "Phoenix, Arizona",
    "Philadelphia, Pennsylvania",
    "San Antonio, Texas",
    "San Diego, California",
    "Dallas, Texas",
    "San Jose, California",
    "Austin, Texas",
    "Jacksonville, Florida",
    "Fort Worth, Texas",
    "Columbus, Ohio",
    "Charlotte, North Carolina",
    "San Francisco, California",
    "Indianapolis, Indiana",
    "Seattle, Washington",
    "Denver, Colorado",
    "Washington, D.C.",
    "Boston, Massachusetts",
    "El Paso, Texas",
    "Nashville, Tennessee",
    "Detroit, Michigan",
    "Oklahoma City, Oklahoma",
    "Portland, Oregon",
    "Las Vegas, Nevada",
    "Memphis, Tennessee",
    "Louisville, Kentucky",
    "Baltimore, Maryland",
    "Milwaukee, Wisconsin",
    "Albuquerque, New Mexico",
    "Tucson, Arizona",
    "Fresno, California",
    "Mesa, Arizona",
    "Sacramento, California",
    "Atlanta, Georgia",
    "Kansas City, Missouri",
    "Colorado Springs, Colorado",
    "Miami, Florida",
    "Raleigh, North Carolina",
    "Omaha, Nebraska",
    "Long Beach, California",
    "Virginia Beach, Virginia",
    "Oakland, California",
    "Minneapolis, Minnesota",
    "Tulsa, Oklahoma",
    "Arlington, Texas",
    "Tampa, Florida",
    "New Orleans, Louisiana"
]


def get_wikipedia_summary(cities: List[str]) -> Dict[str, str]:
    city_summaries = {}
    for city in cities:
        try:
            city_summaries[city] = wiki.summary(city)
        except Exception:
            print(f"error retrieving {city}")

    return city_summaries


def write_data(output_dict: Dict[str, str]) -> None:
    df = pd.DataFrame([output_dict]).T.reset_index()
    df.columns = ['State', 'Wiki Summary']
    df.to_csv(EXPORT_FILENAME, index=False)


def pull_state_data() -> None:
    if EXPORT_FILENAME not in os.listdir():
        print("data not found...pulling wikipedia summaries...")
        city_summary_output = get_wikipedia_summary(CITIES)
        write_data(city_summary_output)
    else:
        print("data already present...skipping download")


if __name__ == "__main__":
    pull_state_data()
```
