Skip to content

Commit f98e50c

Browse files
committed
Addressing feedback and fixing tests
Signed-off-by: Fiona Waters <fiwaters6@gmail.com>
1 parent 455e3e5 commit f98e50c

File tree

6 files changed

+60
-80
lines changed

6 files changed

+60
-80
lines changed

examples/rag-retriever/README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# End-to-end RAG example using Feast, Milvus, and OpenShift AI.
1+
# End-to-end RAG example using Feast and Milvus.
22

33
## Introduction
4-
This example notebook provides a step-by-step demonstration of building and using a RAG system with Feast Feature Store and the custom FeastRagRetriever, on OpenShift AI. The notebook walks through:
4+
This example notebook provides a step-by-step demonstration of building and using a RAG system with Feast Feature Store and the custom FeastRagRetriever. The notebook walks through:
55

66
1. Data Preparation
77
- Loads a subset of the Wikipedia DPR dataset (1% of training data)
@@ -28,13 +28,13 @@ This example notebook provides a step-by-step demonstration of building and usin
2828
- Perform inference with retrieved context
2929

3030
## Requirements
31-
- An OpenShift cluster with OpenShift AI (RHOAI) 2.20+ installed:
32-
- The dashboard, feastoperator and workbenches components enabled.
33-
- Workbench with medium size container, 1 NVIDIA GPU accelerator, and cluster storage of 200GB.
34-
- A standalone Milvus deployment. See example [here](https://github.com/rh-aiservices-bu/llm-on-openshift/tree/main/vector-databases/milvus#deployment).
31+
- A Kubernetes cluster with:
32+
- GPU nodes available (for model inference)
33+
- At least 200GB of storage
34+
- A standalone Milvus deployment. For an example, see the [Milvus Helm chart](https://github.com/milvus-io/milvus-helm/tree/master/charts/milvus).
3535

3636
## Running the example
37-
From the workbench, clone this repository: https://github.com/feast-dev/feast.git
37+
Clone this repository: https://github.com/feast-dev/feast.git
3838
Navigate to the examples/rag-retriever directory. Here you will find the following files:
3939

4040
* **feature_repo/feature_store.yaml**
@@ -48,7 +48,7 @@ Navigate to the examples/rag-retriever directory. Here you will find the followi
4848
* **__feature_repo/ragproject_repo.py__**
4949
This is the Feast feature repository configuration that defines the schema and data source for Wikipedia passage embeddings.
5050

51-
* **__rag_feast_kfto.ipynb__**
51+
* **__rag_feast.ipynb__**
5252
This is a notebook demonstrating the implementation of a RAG system using Feast feature store. The notebook provides:
5353

5454
- A complete end-to-end example of building a RAG system with:
@@ -60,7 +60,7 @@ Navigate to the examples/rag-retriever directory. Here you will find the followi
6060
- Uses `all-MiniLM-L6-v2` for generating embeddings
6161
- Implements `granite-3.2-2b-instruct` as the generator model
6262

63-
Open `rag_feast_kfto.ipynb` and follow the steps in the notebook to run the example.
63+
Open `rag_feast.ipynb` and follow the steps in the notebook to run the example.
6464

6565
## FeastRagRetriever Low Level Design
6666

sdk/python/feast/rag_retriever.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@ def __init__(self):
3333
"""Initialize the Feast index."""
3434
pass
3535

36-
def get_top_docs(self, query_vectors: np.ndarray, n_docs: int = 5):
37-
"""Get top documents (not implemented).
38-
39-
This method is required by the RagRetriever interface but is not used
40-
as we override the retrieve method in FeastRAGRetriever.
41-
"""
42-
raise NotImplementedError("get_top_docs is not yet implemented.")
43-
44-
def get_doc_dicts(self, doc_ids: List[str]):
45-
"""Get document dictionaries (not implemented).
46-
47-
This method is required by the RagRetriever interface but is not used
48-
as we override the retrieve method in FeastRAGRetriever.
49-
"""
50-
raise NotImplementedError("get_doc_dicts is not yet implemented.")
51-
5236

5337
class FeastRAGRetriever(RagRetriever):
5438
"""RAG retriever implementation that uses Feast as a backend."""
@@ -114,12 +98,6 @@ def __init__(
11498
**kwargs,
11599
)
116100

117-
# if torch.cuda.is_available():
118-
# self.question_encoder.to(torch.device("cuda"))
119-
# self.generator_model.to(torch.device("cuda"))
120-
# self.question_encoder = question_encoder
121-
# self.generator_model = generator_model
122-
# self.generator_tokenizer = generator_tokenizer
123101
self.feast_repo_path = feast_repo_path
124102
self.search_type = search_type.lower()
125103
self.format_document = format_document or self._default_format_document

sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
198198
assert existing_tables is not None
199199
assert len(existing_tables) == 1
200200
assert existing_tables[0] == f"test_aws.{db_table_keep_name}"
201-
202201
assert _get_tags(dynamodb_client, existing_tables[0]) == [
203202
{"Key": "some", "Value": "tag"}
204203
]

sdk/python/tests/unit/test_rag_retriever.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from unittest.mock import MagicMock, Mock
2+
from unittest.mock import MagicMock, Mock, patch
33

44
import numpy as np
55
import pytest
@@ -18,7 +18,7 @@
1818
from tests.utils.rag_test_utils import MockVectorStore
1919

2020

21-
@pytest.fixture(autouse=True)
21+
@pytest.fixture
2222
def cleanup_milvus():
2323
"""Ensure Milvus resources are cleaned up between tests."""
2424
yield
@@ -104,7 +104,7 @@ def forward(self, **kwargs):
104104

105105

106106
@pytest.fixture
107-
def feature_store(temp_dir):
107+
def simple_feature_store(temp_dir):
108108
"""Create a simple feature store without auth for RAG tests."""
109109
return FeatureStore(
110110
config=RepoConfig(
@@ -117,7 +117,7 @@ def feature_store(temp_dir):
117117

118118

119119
@pytest.fixture
120-
def rag_retriever(feature_store):
120+
def rag_retriever(simple_feature_store):
121121
"""Create a RAG retriever instance for testing."""
122122
# Import the required objects
123123
from tests.example_repos.example_feature_repo_1 import (
@@ -137,7 +137,7 @@ def rag_retriever(feature_store):
137137
question_encoder=MockModel(),
138138
generator_tokenizer=MockTokenizer(),
139139
generator_model=MockModel(),
140-
feast_repo_path=str(feature_store.repo_path),
140+
feast_repo_path=str(simple_feature_store.repo_path),
141141
feature_view=document_embeddings,
142142
features=[
143143
"document_embeddings:content",
@@ -153,7 +153,7 @@ def rag_retriever(feature_store):
153153

154154
# Replace the vector store with our mock
155155
retriever._vector_store = MockVectorStore(
156-
repo_path=str(feature_store.repo_path),
156+
repo_path=str(simple_feature_store.repo_path),
157157
rag_view=document_embeddings,
158158
features=[
159159
"document_embeddings:content",
@@ -181,13 +181,13 @@ def test_feast_index_initialization():
181181
assert index is not None
182182

183183

184-
def test_feast_index_methods():
185-
"""Test FeastIndex methods raise NotImplementedError."""
186-
index = FeastIndex()
187-
with pytest.raises(NotImplementedError):
188-
index.get_top_docs(np.array([1, 2, 3]))
189-
with pytest.raises(NotImplementedError):
190-
index.get_doc_dicts(["doc1", "doc2"])
184+
# def test_feast_index_methods():
185+
# """Test FeastIndex methods raise NotImplementedError."""
186+
# index = FeastIndex()
187+
# with pytest.raises(NotImplementedError):
188+
# index.get_top_docs(np.array([1, 2, 3]))
189+
# with pytest.raises(NotImplementedError):
190+
# index.get_doc_dicts(["doc1", "doc2"])
191191

192192

193193
def test_rag_retriever_initialization(rag_retriever):
@@ -200,7 +200,7 @@ def test_rag_retriever_initialization(rag_retriever):
200200
assert rag_retriever.format_document is not None # Should have default formatter
201201

202202

203-
def test_rag_retriever_custom_format_document(feature_store):
203+
def test_rag_retriever_custom_format_document(simple_feature_store):
204204
"""Test RAG retriever initialization with custom document formatter."""
205205
from tests.example_repos.example_feature_repo_1 import document_embeddings
206206

@@ -212,7 +212,7 @@ def custom_formatter(doc):
212212
question_encoder=MockModel(),
213213
generator_tokenizer=MockTokenizer(),
214214
generator_model=MockModel(),
215-
feast_repo_path=str(feature_store.repo_path),
215+
feast_repo_path=str(simple_feature_store.repo_path),
216216
feature_view=document_embeddings,
217217
features=[
218218
"document_embeddings:content",
@@ -247,7 +247,7 @@ def test_default_format_document(rag_retriever):
247247
assert "Embeddings" not in formatted # Vector should be skipped
248248

249249

250-
def test_rag_retriever_invalid_search_type(feature_store):
250+
def test_rag_retriever_invalid_search_type(simple_feature_store):
251251
"""Test RAG retriever initialization with invalid search type."""
252252
from tests.example_repos.example_feature_repo_1 import (
253253
document_embeddings,
@@ -259,7 +259,7 @@ def test_rag_retriever_invalid_search_type(feature_store):
259259
question_encoder=MockModel(),
260260
generator_tokenizer=MockTokenizer(),
261261
generator_model=MockModel(),
262-
feast_repo_path=str(feature_store.repo_path),
262+
feast_repo_path=str(simple_feature_store.repo_path),
263263
feature_view=document_embeddings,
264264
features=["content", "title", "Embeddings"],
265265
search_type="invalid",
@@ -437,8 +437,10 @@ def test_retrieve_documents(rag_retriever):
437437
# End-to-end functionality test
438438
def test_generate_answer(rag_retriever):
439439
"""Test generating an answer using the RAG retriever."""
440-
# Mock the retrieve method
441-
rag_retriever.retrieve = Mock(
440+
# Mock the retrieve method using patch
441+
with patch.object(
442+
rag_retriever,
443+
"retrieve",
442444
return_value=(
443445
np.array([[[0.1] * 8, [0.2] * 8]]), # 8-dimensional embeddings
444446
np.array([[1, 2]]),
@@ -449,28 +451,29 @@ def test_generate_answer(rag_retriever):
449451
"title": ["Doc 1", "Doc 2"],
450452
}
451453
],
454+
),
455+
) as mock_retrieve:
456+
# Mock the generator model's generate method
457+
rag_retriever.generator_model.generate = Mock(
458+
return_value=torch.tensor([[1, 2, 3]])
452459
)
453-
)
454-
455-
# Mock the generator model's generate method
456-
rag_retriever.generator_model.generate = Mock(
457-
return_value=torch.tensor([[1, 2, 3]])
458-
)
459460

460-
# Generate an answer
461-
answer = rag_retriever.generate_answer("test query", top_k=2, max_new_tokens=100)
461+
# Generate an answer
462+
answer = rag_retriever.generate_answer(
463+
"test query", top_k=2, max_new_tokens=100
464+
)
462465

463-
# Verify the answer
464-
assert isinstance(answer, str)
465-
assert len(answer) > 0
466+
# Verify the answer
467+
assert isinstance(answer, str)
468+
assert len(answer) > 0
466469

467-
# Verify that retrieve was called with correct parameters
468-
rag_retriever.retrieve.assert_called_once()
469-
call_args = rag_retriever.retrieve.call_args[1]
470-
assert call_args["n_docs"] == 2
471-
assert call_args["query"] == "test query"
470+
# Verify that retrieve was called with correct parameters
471+
mock_retrieve.assert_called_once()
472+
call_args = mock_retrieve.call_args[1]
473+
assert call_args["n_docs"] == 2
474+
assert call_args["query"] == "test query"
472475

473-
# Verify that generate was called with correct parameters
474-
rag_retriever.generator_model.generate.assert_called_once()
475-
call_args = rag_retriever.generator_model.generate.call_args[1]
476-
assert call_args["max_new_tokens"] == 100
476+
# Verify that generate was called with correct parameters
477+
rag_retriever.generator_model.generate.assert_called_once()
478+
call_args = rag_retriever.generator_model.generate.call_args[1]
479+
assert call_args["max_new_tokens"] == 100

sdk/python/tests/unit/test_vector_store.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ def cleanup_milvus():
1919
time.sleep(0.1)
2020

2121

22-
def test_vector_store_initialization(feature_store):
22+
def test_vector_store_initialization(example_feature_store):
2323
"""Test vector store initialization."""
2424
print("Testing vector store initialization...")
2525

2626
# Apply the feature view first
27-
feature_store.apply([document_embeddings])
27+
example_feature_store.apply([document_embeddings])
2828

29-
doc_view = feature_store.get_feature_view("document_embeddings")
29+
doc_view = example_feature_store.get_feature_view("document_embeddings")
3030
store = MockVectorStore(
31-
repo_path=str(feature_store.repo_path),
31+
repo_path=str(example_feature_store.repo_path),
3232
rag_view=doc_view,
3333
features=[
3434
"document_embeddings:content",
@@ -44,16 +44,16 @@ def test_vector_store_initialization(feature_store):
4444
]
4545

4646

47-
def test_vector_store_query(feature_store):
47+
def test_vector_store_query(example_feature_store):
4848
"""Test vector store query method."""
4949
print("Testing vector store query...")
5050

5151
# Apply the feature view first
52-
feature_store.apply([document_embeddings])
52+
example_feature_store.apply([document_embeddings])
5353

54-
doc_view = feature_store.get_feature_view("document_embeddings")
54+
doc_view = example_feature_store.get_feature_view("document_embeddings")
5555
store = MockVectorStore(
56-
repo_path=str(feature_store.repo_path),
56+
repo_path=str(example_feature_store.repo_path),
5757
rag_view=doc_view,
5858
features=[
5959
"document_embeddings:content",

sdk/python/tests/utils/rag_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def create_test_data():
7171

7272

7373
@pytest.fixture
74-
def feature_store():
74+
def example_feature_store():
7575
"""Create a feature store using example repo."""
7676
runner = CliRunner()
7777
# Patch the run method to always succeed for teardown

0 commit comments

Comments
 (0)