Skip to content

Commit 04fd031

Browse files
committed
Addressing feedback and fixing tests
Signed-off-by: Fiona Waters <fiwaters6@gmail.com>
1 parent 455e3e5 commit 04fd031

File tree

5 files changed

+51
-71
lines changed

5 files changed

+51
-71
lines changed

sdk/python/feast/rag_retriever.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@ def __init__(self):
3333
"""Initialize the Feast index."""
3434
pass
3535

36-
def get_top_docs(self, query_vectors: np.ndarray, n_docs: int = 5):
37-
"""Get top documents (not implemented).
38-
39-
This method is required by the RagRetriever interface but is not used
40-
as we override the retrieve method in FeastRAGRetriever.
41-
"""
42-
raise NotImplementedError("get_top_docs is not yet implemented.")
43-
44-
def get_doc_dicts(self, doc_ids: List[str]):
45-
"""Get document dictionaries (not implemented).
46-
47-
This method is required by the RagRetriever interface but is not used
48-
as we override the retrieve method in FeastRAGRetriever.
49-
"""
50-
raise NotImplementedError("get_doc_dicts is not yet implemented.")
51-
5236

5337
class FeastRAGRetriever(RagRetriever):
5438
"""RAG retriever implementation that uses Feast as a backend."""
@@ -114,12 +98,6 @@ def __init__(
11498
**kwargs,
11599
)
116100

117-
# if torch.cuda.is_available():
118-
# self.question_encoder.to(torch.device("cuda"))
119-
# self.generator_model.to(torch.device("cuda"))
120-
# self.question_encoder = question_encoder
121-
# self.generator_model = generator_model
122-
# self.generator_tokenizer = generator_tokenizer
123101
self.feast_repo_path = feast_repo_path
124102
self.search_type = search_type.lower()
125103
self.format_document = format_document or self._default_format_document

sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
198198
assert existing_tables is not None
199199
assert len(existing_tables) == 1
200200
assert existing_tables[0] == f"test_aws.{db_table_keep_name}"
201-
202201
assert _get_tags(dynamodb_client, existing_tables[0]) == [
203202
{"Key": "some", "Value": "tag"}
204203
]

sdk/python/tests/unit/test_rag_retriever.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from unittest.mock import MagicMock, Mock
2+
from unittest.mock import MagicMock, Mock, patch
33

44
import numpy as np
55
import pytest
@@ -18,7 +18,7 @@
1818
from tests.utils.rag_test_utils import MockVectorStore
1919

2020

21-
@pytest.fixture(autouse=True)
21+
@pytest.fixture
2222
def cleanup_milvus():
2323
"""Ensure Milvus resources are cleaned up between tests."""
2424
yield
@@ -104,7 +104,7 @@ def forward(self, **kwargs):
104104

105105

106106
@pytest.fixture
107-
def feature_store(temp_dir):
107+
def simple_feature_store(temp_dir):
108108
"""Create a simple feature store without auth for RAG tests."""
109109
return FeatureStore(
110110
config=RepoConfig(
@@ -117,7 +117,7 @@ def feature_store(temp_dir):
117117

118118

119119
@pytest.fixture
120-
def rag_retriever(feature_store):
120+
def rag_retriever(simple_feature_store):
121121
"""Create a RAG retriever instance for testing."""
122122
# Import the required objects
123123
from tests.example_repos.example_feature_repo_1 import (
@@ -137,7 +137,7 @@ def rag_retriever(feature_store):
137137
question_encoder=MockModel(),
138138
generator_tokenizer=MockTokenizer(),
139139
generator_model=MockModel(),
140-
feast_repo_path=str(feature_store.repo_path),
140+
feast_repo_path=str(simple_feature_store.repo_path),
141141
feature_view=document_embeddings,
142142
features=[
143143
"document_embeddings:content",
@@ -153,7 +153,7 @@ def rag_retriever(feature_store):
153153

154154
# Replace the vector store with our mock
155155
retriever._vector_store = MockVectorStore(
156-
repo_path=str(feature_store.repo_path),
156+
repo_path=str(simple_feature_store.repo_path),
157157
rag_view=document_embeddings,
158158
features=[
159159
"document_embeddings:content",
@@ -181,13 +181,13 @@ def test_feast_index_initialization():
181181
assert index is not None
182182

183183

184-
def test_feast_index_methods():
185-
"""Test FeastIndex methods raise NotImplementedError."""
186-
index = FeastIndex()
187-
with pytest.raises(NotImplementedError):
188-
index.get_top_docs(np.array([1, 2, 3]))
189-
with pytest.raises(NotImplementedError):
190-
index.get_doc_dicts(["doc1", "doc2"])
184+
# def test_feast_index_methods():
185+
# """Test FeastIndex methods raise NotImplementedError."""
186+
# index = FeastIndex()
187+
# with pytest.raises(NotImplementedError):
188+
# index.get_top_docs(np.array([1, 2, 3]))
189+
# with pytest.raises(NotImplementedError):
190+
# index.get_doc_dicts(["doc1", "doc2"])
191191

192192

193193
def test_rag_retriever_initialization(rag_retriever):
@@ -200,7 +200,7 @@ def test_rag_retriever_initialization(rag_retriever):
200200
assert rag_retriever.format_document is not None # Should have default formatter
201201

202202

203-
def test_rag_retriever_custom_format_document(feature_store):
203+
def test_rag_retriever_custom_format_document(simple_feature_store):
204204
"""Test RAG retriever initialization with custom document formatter."""
205205
from tests.example_repos.example_feature_repo_1 import document_embeddings
206206

@@ -212,7 +212,7 @@ def custom_formatter(doc):
212212
question_encoder=MockModel(),
213213
generator_tokenizer=MockTokenizer(),
214214
generator_model=MockModel(),
215-
feast_repo_path=str(feature_store.repo_path),
215+
feast_repo_path=str(simple_feature_store.repo_path),
216216
feature_view=document_embeddings,
217217
features=[
218218
"document_embeddings:content",
@@ -247,7 +247,7 @@ def test_default_format_document(rag_retriever):
247247
assert "Embeddings" not in formatted # Vector should be skipped
248248

249249

250-
def test_rag_retriever_invalid_search_type(feature_store):
250+
def test_rag_retriever_invalid_search_type(simple_feature_store):
251251
"""Test RAG retriever initialization with invalid search type."""
252252
from tests.example_repos.example_feature_repo_1 import (
253253
document_embeddings,
@@ -259,7 +259,7 @@ def test_rag_retriever_invalid_search_type(feature_store):
259259
question_encoder=MockModel(),
260260
generator_tokenizer=MockTokenizer(),
261261
generator_model=MockModel(),
262-
feast_repo_path=str(feature_store.repo_path),
262+
feast_repo_path=str(simple_feature_store.repo_path),
263263
feature_view=document_embeddings,
264264
features=["content", "title", "Embeddings"],
265265
search_type="invalid",
@@ -437,8 +437,10 @@ def test_retrieve_documents(rag_retriever):
437437
# End-to-end functionality test
438438
def test_generate_answer(rag_retriever):
439439
"""Test generating an answer using the RAG retriever."""
440-
# Mock the retrieve method
441-
rag_retriever.retrieve = Mock(
440+
# Mock the retrieve method using patch
441+
with patch.object(
442+
rag_retriever,
443+
"retrieve",
442444
return_value=(
443445
np.array([[[0.1] * 8, [0.2] * 8]]), # 8-dimensional embeddings
444446
np.array([[1, 2]]),
@@ -449,28 +451,29 @@ def test_generate_answer(rag_retriever):
449451
"title": ["Doc 1", "Doc 2"],
450452
}
451453
],
454+
),
455+
) as mock_retrieve:
456+
# Mock the generator model's generate method
457+
rag_retriever.generator_model.generate = Mock(
458+
return_value=torch.tensor([[1, 2, 3]])
452459
)
453-
)
454-
455-
# Mock the generator model's generate method
456-
rag_retriever.generator_model.generate = Mock(
457-
return_value=torch.tensor([[1, 2, 3]])
458-
)
459460

460-
# Generate an answer
461-
answer = rag_retriever.generate_answer("test query", top_k=2, max_new_tokens=100)
461+
# Generate an answer
462+
answer = rag_retriever.generate_answer(
463+
"test query", top_k=2, max_new_tokens=100
464+
)
462465

463-
# Verify the answer
464-
assert isinstance(answer, str)
465-
assert len(answer) > 0
466+
# Verify the answer
467+
assert isinstance(answer, str)
468+
assert len(answer) > 0
466469

467-
# Verify that retrieve was called with correct parameters
468-
rag_retriever.retrieve.assert_called_once()
469-
call_args = rag_retriever.retrieve.call_args[1]
470-
assert call_args["n_docs"] == 2
471-
assert call_args["query"] == "test query"
470+
# Verify that retrieve was called with correct parameters
471+
mock_retrieve.assert_called_once()
472+
call_args = mock_retrieve.call_args[1]
473+
assert call_args["n_docs"] == 2
474+
assert call_args["query"] == "test query"
472475

473-
# Verify that generate was called with correct parameters
474-
rag_retriever.generator_model.generate.assert_called_once()
475-
call_args = rag_retriever.generator_model.generate.call_args[1]
476-
assert call_args["max_new_tokens"] == 100
476+
# Verify that generate was called with correct parameters
477+
rag_retriever.generator_model.generate.assert_called_once()
478+
call_args = rag_retriever.generator_model.generate.call_args[1]
479+
assert call_args["max_new_tokens"] == 100

sdk/python/tests/unit/test_vector_store.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ def cleanup_milvus():
1919
time.sleep(0.1)
2020

2121

22-
def test_vector_store_initialization(feature_store):
22+
def test_vector_store_initialization(example_feature_store):
2323
"""Test vector store initialization."""
2424
print("Testing vector store initialization...")
2525

2626
# Apply the feature view first
27-
feature_store.apply([document_embeddings])
27+
example_feature_store.apply([document_embeddings])
2828

29-
doc_view = feature_store.get_feature_view("document_embeddings")
29+
doc_view = example_feature_store.get_feature_view("document_embeddings")
3030
store = MockVectorStore(
31-
repo_path=str(feature_store.repo_path),
31+
repo_path=str(example_feature_store.repo_path),
3232
rag_view=doc_view,
3333
features=[
3434
"document_embeddings:content",
@@ -44,16 +44,16 @@ def test_vector_store_initialization(feature_store):
4444
]
4545

4646

47-
def test_vector_store_query(feature_store):
47+
def test_vector_store_query(example_feature_store):
4848
"""Test vector store query method."""
4949
print("Testing vector store query...")
5050

5151
# Apply the feature view first
52-
feature_store.apply([document_embeddings])
52+
example_feature_store.apply([document_embeddings])
5353

54-
doc_view = feature_store.get_feature_view("document_embeddings")
54+
doc_view = example_feature_store.get_feature_view("document_embeddings")
5555
store = MockVectorStore(
56-
repo_path=str(feature_store.repo_path),
56+
repo_path=str(example_feature_store.repo_path),
5757
rag_view=doc_view,
5858
features=[
5959
"document_embeddings:content",

sdk/python/tests/utils/rag_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def create_test_data():
7171

7272

7373
@pytest.fixture
74-
def feature_store():
74+
def example_feature_store():
7575
"""Create a feature store using example repo."""
7676
runner = CliRunner()
7777
# Patch the run method to always succeed for teardown

0 commit comments

Comments
 (0)