11import os
2- from unittest .mock import MagicMock , Mock
2+ from unittest .mock import MagicMock , Mock , patch
33
44import numpy as np
55import pytest
1818from tests .utils .rag_test_utils import MockVectorStore
1919
2020
21- @pytest .fixture ( autouse = True )
21+ @pytest .fixture
2222def cleanup_milvus ():
2323 """Ensure Milvus resources are cleaned up between tests."""
2424 yield
@@ -104,7 +104,7 @@ def forward(self, **kwargs):
104104
105105
106106@pytest .fixture
107- def feature_store (temp_dir ):
107+ def simple_feature_store (temp_dir ):
108108 """Create a simple feature store without auth for RAG tests."""
109109 return FeatureStore (
110110 config = RepoConfig (
@@ -117,7 +117,7 @@ def feature_store(temp_dir):
117117
118118
119119@pytest .fixture
120- def rag_retriever (feature_store ):
120+ def rag_retriever (simple_feature_store ):
121121 """Create a RAG retriever instance for testing."""
122122 # Import the required objects
123123 from tests .example_repos .example_feature_repo_1 import (
@@ -137,7 +137,7 @@ def rag_retriever(feature_store):
137137 question_encoder = MockModel (),
138138 generator_tokenizer = MockTokenizer (),
139139 generator_model = MockModel (),
140- feast_repo_path = str (feature_store .repo_path ),
140+ feast_repo_path = str (simple_feature_store .repo_path ),
141141 feature_view = document_embeddings ,
142142 features = [
143143 "document_embeddings:content" ,
@@ -153,7 +153,7 @@ def rag_retriever(feature_store):
153153
154154 # Replace the vector store with our mock
155155 retriever ._vector_store = MockVectorStore (
156- repo_path = str (feature_store .repo_path ),
156+ repo_path = str (simple_feature_store .repo_path ),
157157 rag_view = document_embeddings ,
158158 features = [
159159 "document_embeddings:content" ,
@@ -181,13 +181,13 @@ def test_feast_index_initialization():
181181 assert index is not None
182182
183183
184- def test_feast_index_methods ():
185- """Test FeastIndex methods raise NotImplementedError."""
186- index = FeastIndex ()
187- with pytest .raises (NotImplementedError ):
188- index .get_top_docs (np .array ([1 , 2 , 3 ]))
189- with pytest .raises (NotImplementedError ):
190- index .get_doc_dicts (["doc1" , "doc2" ])
184+ # def test_feast_index_methods():
185+ # """Test FeastIndex methods raise NotImplementedError."""
186+ # index = FeastIndex()
187+ # with pytest.raises(NotImplementedError):
188+ # index.get_top_docs(np.array([1, 2, 3]))
189+ # with pytest.raises(NotImplementedError):
190+ # index.get_doc_dicts(["doc1", "doc2"])
191191
192192
193193def test_rag_retriever_initialization (rag_retriever ):
@@ -200,7 +200,7 @@ def test_rag_retriever_initialization(rag_retriever):
200200 assert rag_retriever .format_document is not None # Should have default formatter
201201
202202
203- def test_rag_retriever_custom_format_document (feature_store ):
203+ def test_rag_retriever_custom_format_document (simple_feature_store ):
204204 """Test RAG retriever initialization with custom document formatter."""
205205 from tests .example_repos .example_feature_repo_1 import document_embeddings
206206
@@ -212,7 +212,7 @@ def custom_formatter(doc):
212212 question_encoder = MockModel (),
213213 generator_tokenizer = MockTokenizer (),
214214 generator_model = MockModel (),
215- feast_repo_path = str (feature_store .repo_path ),
215+ feast_repo_path = str (simple_feature_store .repo_path ),
216216 feature_view = document_embeddings ,
217217 features = [
218218 "document_embeddings:content" ,
@@ -247,7 +247,7 @@ def test_default_format_document(rag_retriever):
247247 assert "Embeddings" not in formatted # Vector should be skipped
248248
249249
250- def test_rag_retriever_invalid_search_type (feature_store ):
250+ def test_rag_retriever_invalid_search_type (simple_feature_store ):
251251 """Test RAG retriever initialization with invalid search type."""
252252 from tests .example_repos .example_feature_repo_1 import (
253253 document_embeddings ,
@@ -259,7 +259,7 @@ def test_rag_retriever_invalid_search_type(feature_store):
259259 question_encoder = MockModel (),
260260 generator_tokenizer = MockTokenizer (),
261261 generator_model = MockModel (),
262- feast_repo_path = str (feature_store .repo_path ),
262+ feast_repo_path = str (simple_feature_store .repo_path ),
263263 feature_view = document_embeddings ,
264264 features = ["content" , "title" , "Embeddings" ],
265265 search_type = "invalid" ,
@@ -437,8 +437,10 @@ def test_retrieve_documents(rag_retriever):
437437# End-to-end functionality test
438438def test_generate_answer (rag_retriever ):
439439 """Test generating an answer using the RAG retriever."""
440- # Mock the retrieve method
441- rag_retriever .retrieve = Mock (
440+ # Mock the retrieve method using patch
441+ with patch .object (
442+ rag_retriever ,
443+ "retrieve" ,
442444 return_value = (
443445 np .array ([[[0.1 ] * 8 , [0.2 ] * 8 ]]), # 8-dimensional embeddings
444446 np .array ([[1 , 2 ]]),
@@ -449,28 +451,29 @@ def test_generate_answer(rag_retriever):
449451 "title" : ["Doc 1" , "Doc 2" ],
450452 }
451453 ],
454+ ),
455+ ) as mock_retrieve :
456+ # Mock the generator model's generate method
457+ rag_retriever .generator_model .generate = Mock (
458+ return_value = torch .tensor ([[1 , 2 , 3 ]])
452459 )
453- )
454-
455- # Mock the generator model's generate method
456- rag_retriever .generator_model .generate = Mock (
457- return_value = torch .tensor ([[1 , 2 , 3 ]])
458- )
459460
460- # Generate an answer
461- answer = rag_retriever .generate_answer ("test query" , top_k = 2 , max_new_tokens = 100 )
461+ # Generate an answer
462+ answer = rag_retriever .generate_answer (
463+ "test query" , top_k = 2 , max_new_tokens = 100
464+ )
462465
463- # Verify the answer
464- assert isinstance (answer , str )
465- assert len (answer ) > 0
466+ # Verify the answer
467+ assert isinstance (answer , str )
468+ assert len (answer ) > 0
466469
467- # Verify that retrieve was called with correct parameters
468- rag_retriever . retrieve .assert_called_once ()
469- call_args = rag_retriever . retrieve .call_args [1 ]
470- assert call_args ["n_docs" ] == 2
471- assert call_args ["query" ] == "test query"
470+ # Verify that retrieve was called with correct parameters
471+ mock_retrieve .assert_called_once ()
472+ call_args = mock_retrieve .call_args [1 ]
473+ assert call_args ["n_docs" ] == 2
474+ assert call_args ["query" ] == "test query"
472475
473- # Verify that generate was called with correct parameters
474- rag_retriever .generator_model .generate .assert_called_once ()
475- call_args = rag_retriever .generator_model .generate .call_args [1 ]
476- assert call_args ["max_new_tokens" ] == 100
476+ # Verify that generate was called with correct parameters
477+ rag_retriever .generator_model .generate .assert_called_once ()
478+ call_args = rag_retriever .generator_model .generate .call_args [1 ]
479+ assert call_args ["max_new_tokens" ] == 100
0 commit comments