diff --git a/RAG/examples/advanced_rag/multi_turn_rag/chains.py b/RAG/examples/advanced_rag/multi_turn_rag/chains.py index dd4664969..87ea0f918 100644 --- a/RAG/examples/advanced_rag/multi_turn_rag/chains.py +++ b/RAG/examples/advanced_rag/multi_turn_rag/chains.py @@ -127,15 +127,31 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene # ] # ) - # This is a workaround Prompt Template + logger.info(f"Chat history: {chat_history}") + + conversation_history = [(msg.role, msg.content) for msg in chat_history] + system_message = [("system", prompts.get("multi_turn_rag_template", ""))] + user_message = [("user", "{input}")] + + # Checking if conversation_history is not None and not empty chat_prompt = ChatPromptTemplate.from_messages( - [("user", prompts.get("multi_turn_rag_template") + "User Query: {input}"),] + system_message + conversation_history + user_message + ) if conversation_history else ChatPromptTemplate.from_messages( + system_message + user_message ) + logger.info(f"Formulated chat prompt template: {chat_prompt}") + + #chat_prompt = ChatPromptTemplate.from_messages(system_message + user_message) + # This is a workaround Prompt Template + # chat_prompt = ChatPromptTemplate.from_messages( + # [("user", prompts.get("multi_turn_rag_template") + "User Query: {input}"),] + # ) + llm = get_llm(**kwargs) stream_chain = chat_prompt | llm | StrOutputParser() - convstore = create_vectorstore_langchain(document_embedder, collection_name="conv_store") + # convstore = create_vectorstore_langchain(document_embedder, collection_name="conv_store") resp_str = "" # TODO Integrate chat_history @@ -161,15 +177,16 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene } ) - history_chain = RunnableAssign( - { - "history": itemgetter("input") - | convstore.as_retriever( - search_type="similarity_score_threshold", - search_kwargs={"score_threshold": settings.retriever.score_threshold, "k": top_k}, - ) - } - ) + # history_chain = RunnableAssign( + # { + # "history": itemgetter("input") + # | convstore.as_retriever( + # search_type="similarity_score_threshold", + # search_kwargs={"score_threshold": settings.retriever.score_threshold, "k": top_k}, + # ) + # } + # ) + if ranker: logger.info( f"Narrowing the collection from {top_k} results and further narrowing it to {settings.retriever.top_k} with the reranker." @@ -181,17 +198,17 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene ) } ) - history_reranker = RunnableAssign( - { - "history": lambda input: ranker.compress_documents( - query=input['input'], documents=input['history'] - ) - } - ) - - retrieval_chain = context_chain | context_reranker | history_chain | history_reranker + # history_reranker = RunnableAssign( + # { + # "history": lambda input: ranker.compress_documents( + # query=input['input'], documents=input['history'] + # ) + # } + # ) + + retrieval_chain = context_chain | context_reranker #| history_chain | history_reranker else: - retrieval_chain = context_chain | history_chain + retrieval_chain = context_chain #| history_chain # Handling Retrieval failure docs = retrieval_chain.invoke({"input": query}, config={"callbacks": [self.cb_handler]}) if not docs: @@ -210,7 +227,7 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene yield chunk resp_str += chunk - self.save_memory_and_get_output({"input": query, "output": resp_str}, convstore) + #self.save_memory_and_get_output({"input": query, "output": resp_str}, convstore) return chain.stream(query, config={"callbacks": [self.cb_handler]}) @@ -223,9 +240,9 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene {"context": itemgetter("input") | ds.as_retriever(search_kwargs={"k": top_k})} ) - history_chain = RunnableAssign( - {"history": itemgetter("input") | convstore.as_retriever(search_kwargs={"k": top_k})} - ) + # history_chain = RunnableAssign( + # {"history": itemgetter("input") | convstore.as_retriever(search_kwargs={"k": top_k})} + # ) if ranker: logger.info( f"Narrowing the collection from {top_k} results and further narrowing it to {settings.retriever.top_k} with the reranker." @@ -237,17 +254,17 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene ) } ) - history_reranker = RunnableAssign( - { - "history": lambda input: ranker.compress_documents( - query=input['input'], documents=input['history'] - ) - } - ) - - retrieval_chain = context_chain | context_reranker | history_chain | history_reranker + # history_reranker = RunnableAssign( + # { + # "history": lambda input: ranker.compress_documents( + # query=input['input'], documents=input['history'] + # ) + # } + # ) + + retrieval_chain = context_chain | context_reranker #| history_chain | history_reranker else: - retrieval_chain = context_chain | history_chain + retrieval_chain = context_chain #| history_chain # Handling Retrieval failure docs = retrieval_chain.invoke({"input": query}, config={"callbacks": [self.cb_handler]}) @@ -265,7 +282,7 @@ def rag_chain(self, query: str, chat_history: List["Message"], **kwargs) -> Gene yield chunk resp_str += chunk - self.save_memory_and_get_output({"input": query, "output": resp_str}, convstore) + #self.save_memory_and_get_output({"input": query, "output": resp_str}, convstore) return chain.stream(query, config={"callbacks": [self.cb_handler]}) diff --git a/RAG/examples/advanced_rag/multi_turn_rag/prompt.yaml b/RAG/examples/advanced_rag/multi_turn_rag/prompt.yaml index 5e63a6443..ee8147637 100644 --- a/RAG/examples/advanced_rag/multi_turn_rag/prompt.yaml +++ b/RAG/examples/advanced_rag/multi_turn_rag/prompt.yaml @@ -15,8 +15,6 @@ multi_turn_rag_template: | You are a document chatbot. Help the user as they ask questions about documents. User message just asked: {input}\n\n For this, we have retrieved the following potentially-useful info: - Conversation History Retrieved: - {history}\n\n Document Retrieved: {context}\n\n Answer only from retrieved data. Make your response conversational. diff --git a/RAG/examples/local_deploy/docker-compose-vectordb.yaml b/RAG/examples/local_deploy/docker-compose-vectordb.yaml index cd76bc981..4f7a9ecf7 100644 --- a/RAG/examples/local_deploy/docker-compose-vectordb.yaml +++ b/RAG/examples/local_deploy/docker-compose-vectordb.yaml @@ -54,7 +54,7 @@ services: milvus: container_name: milvus-standalone - image: milvusdb/milvus:v2.4.5 + image: milvusdb/milvus:v2.4.15-gpu command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcd:2379 @@ -74,6 +74,13 @@ services: depends_on: - "etcd" - "minio" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] + device_ids: ['${VECTORSTORE_GPU_DEVICE_ID:-0}'] profiles: ["nemo-retriever", "milvus", ""] elasticsearch: diff --git a/RAG/src/chain_server/configuration.py b/RAG/src/chain_server/configuration.py index 2fead6c5d..ba0dde217 100644 --- a/RAG/src/chain_server/configuration.py +++ b/RAG/src/chain_server/configuration.py @@ -40,7 +40,7 @@ class VectorStoreConfig(ConfigWizard): "nprobe", default=16, help_txt="Number of units to query", # IVF Flat milvus ) index_type: str = configfield( - "index_type", default="IVF_FLAT", help_txt="Index of the vector db", # IVF Flat for milvus + "index_type", default="GPU_IVF_FLAT", help_txt="Index of the vector db", # IVF Flat for milvus ) diff --git a/RAG/src/chain_server/utils.py b/RAG/src/chain_server/utils.py index caa8169bb..8986f6075 100644 --- a/RAG/src/chain_server/utils.py +++ b/RAG/src/chain_server/utils.py @@ -314,13 +314,17 @@ def create_vectorstore_langchain(document_embedder: "Embeddings", collection_nam ) elif config.vector_store.name == "milvus": logger.info(f"Using milvus collection: {collection_name}") - # vectorstore url can be updated using environment variable APP_VECTORSTORE_URL, it should be in http://ip:port format + if not collection_name: + collection_name = os.getenv('COLLECTION_NAME', "vector_db") + logger.info(f"Using milvus collection: {collection_name}") url = urlparse(config.vector_store.url) vectorstore = Milvus( document_embedder, connection_args={"host": url.hostname, "port": url.port}, collection_name=collection_name, - auto_id=True, + index_params={"index_type": config.vector_store.index_type, "metric_type": "L2", "nlist": config.vector_store.nlist}, + search_params={"nprobe": config.vector_store.nprobe}, + auto_id = True ) else: raise ValueError(f"{config.vector_store.name} vector database is not supported") diff --git a/RAG/src/rag_playground/requirements.txt b/RAG/src/rag_playground/requirements.txt index c94be2174..e8d22682e 100644 --- a/RAG/src/rag_playground/requirements.txt +++ b/RAG/src/rag_playground/requirements.txt @@ -1,6 +1,6 @@ PyYAML==6.0.1 dataclass-wizard==0.22.3 -gradio==4.13.0 +gradio==4.43.0 jinja2==3.1.3 numpy==1.26.4 opentelemetry-api==1.23.0