How to Enhance RAG Performance with CRAG?

Sunil Kumar Dash 14 Mar, 2024 • 11 min read

Introduction

In this article, we will learn how to enhance RAG performance with CRAG. The term RAG has been floating around for a while, and for good reason. Large language models have made it possible to build solutions for problems that were previously difficult. Question answering over large amounts of data was one such problem, and it is now feasible thanks to LLMs, AI frameworks, and tools such as vector databases.

Instead of only matching keywords and metadata to find similar texts, we can compute the cosine similarity between text embeddings to retrieve relevant matches, and then use the matched text chunks to generate a coherent answer with an LLM. This method is called RAG (Retrieval Augmented Generation). But is vector retrieval always sufficient? Can we rely on RAG when the retrieved chunks do not contain the answer to the question? This is where CRAG, or Corrective Retrieval Augmented Generation, comes into the picture.

Learning Objectives

  • Learn about the limitations of RAG.
  • Understand what CRAG is and how it improves RAG.
  • Learn about LangGraph, a library for building RAG apps as graphs.
  • Use LangGraph to implement CRAG.

This article was published as a part of the Data Science Blogathon.

What Are the Limitations of RAG?

RAG works well for question answering over text documents. The process is straightforward: we extract the contents from documents, pre-process them, compute embeddings, and store them in a vector database. We then compute similarity scores between the query and the stored chunks to find the most semantically similar text chunks. These chunks are then fed to an LLM to generate a human-readable answer.
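To make the retrieval step concrete, here is a minimal, dependency-free sketch of top-k retrieval by cosine similarity. The tiny hand-written vectors are hypothetical stand-ins for real embeddings; in practice an embedding model and a vector database do this work.

```python
import math
from typing import Dict, List, Tuple

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # a zero vector has no direction; treat it as dissimilar
    return dot / (norm_a * norm_b)

def top_k(query_vec: List[float], index: Dict[str, List[float]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k chunk ids most similar to the query, best first."""
    scored = [(chunk_id, cosine_similarity(query_vec, vec)) for chunk_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

# Hypothetical 3-dimensional "embeddings" for three text chunks.
index = {
    "chunk_hnsw_layers": [0.9, 0.1, 0.0],
    "chunk_graph_search": [0.7, 0.6, 0.1],
    "chunk_unrelated": [0.0, 0.1, 0.95],
}
query = [0.8, 0.3, 0.0]
results = top_k(query, index, k=2)
print(results)
```

The selected chunks would then be inserted into the LLM prompt. Note that nothing in this step checks whether the top-k chunks actually answer the question, which is the gap CRAG targets.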

[Figure: A conventional RAG pipeline]

This is simple and effective for most use cases, but not all. Ranking documents by cosine similarity alone does not always surface the right ones, and passing the top-k chunks to the LLM regardless of their relevance is risky where the cost of false information is high.

To mitigate this, the primary knowledge sources can be supplemented with external sources such as the web. Web access has been observed to improve the question-answering ability of LLMs, and web integration is arguably a large part of the success of Bard (Gemini Pro) and Perplexity AI.

Note the performance gap between Gemini Pro with web access and vanilla Gemini Pro on the LMSYS Chatbot Arena leaderboard.

[Figure: Gemini Pro with and without web access on the LMSYS leaderboard]

Corrective RAG is based on the same principle: it introduces the internet as a third source of knowledge that supplements the primary knowledge base. Let's understand how it works.

What is CRAG?

The word “corrective” in CRAG refers to a corrective module added to the existing RAG pipeline. This module is responsible for correcting wrong retrieval results. The idea was proposed in the paper Corrective Retrieval Augmented Generation, which describes how to build a CRAG system and evaluates it on several benchmarks. Let's look at the fundamental architecture of CRAG.

[Figure: CRAG architecture]

As you can observe, there are three new additions to a conventional RAG architecture: an evaluator, knowledge refinement, and knowledge searching.

Evaluator

The evaluator is a model that assesses how relevant each retrieved text chunk is to the question. The authors used a fine-tuned T5-large model as the evaluator, but other models, including general-purpose LLMs, can be used. Based on these relevance judgments, the retrieval is classified as correct, incorrect, or ambiguous. The accuracy of the evaluator plays a crucial role here, since every later step depends on its judgment.
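One way to turn per-chunk relevance scores into the three labels is with an upper and a lower threshold. The sketch below follows the triggering rule described in the CRAG paper as we understand it: “correct” if at least one chunk scores above the upper threshold, “incorrect” if every chunk scores below the lower threshold, and “ambiguous” otherwise. The score range and the threshold values are hypothetical assumptions, not the paper's tuned settings.

```python
from typing import List

# Hypothetical thresholds on relevance scores in [0, 1]; the paper tunes its own values.
UPPER_THRESHOLD = 0.7
LOWER_THRESHOLD = 0.3

def classify_retrieval(scores: List[float],
                       upper: float = UPPER_THRESHOLD,
                       lower: float = LOWER_THRESHOLD) -> str:
    """Map per-chunk relevance scores to a CRAG action label."""
    if any(s > upper for s in scores):
        return "correct"    # at least one clearly relevant chunk
    if all(s < lower for s in scores):
        return "incorrect"  # every chunk clearly irrelevant (an empty list also lands here)
    return "ambiguous"      # in between: refine the chunks and also search the web

print(classify_retrieval([0.9, 0.2]))   # correct
print(classify_retrieval([0.1, 0.25]))  # incorrect
print(classify_retrieval([0.5, 0.2]))   # ambiguous
```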

Knowledge Refinement

When the retrieval is classified as correct, the relevant chunks undergo further pruning to produce a refined source of knowledge. The text chunks are decomposed into small knowledge strips (1-2 sentences), and the evaluator is used again to filter out irrelevant strips. The remaining strips are rejoined and sent to the LLM for answer generation.
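The decompose, filter, and recompose steps can be sketched as follows. Splitting on sentence-ending punctuation and scoring strips by word overlap with the question are simplifying assumptions; `overlap_score` is a hypothetical stand-in for the evaluator, and the 0.2 threshold is arbitrary.

```python
import re
from typing import Callable, List

def split_into_strips(text: str) -> List[str]:
    """Decompose a chunk into one-sentence knowledge strips."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def overlap_score(question: str, strip: str) -> float:
    """Hypothetical evaluator stand-in: fraction of question words that appear in the strip."""
    q_words = set(re.findall(r"\w+", question.lower()))
    s_words = set(re.findall(r"\w+", strip.lower()))
    return len(q_words & s_words) / len(q_words) if q_words else 0.0

def refine_knowledge(question: str, chunks: List[str],
                     scorer: Callable[[str, str], float] = overlap_score,
                     threshold: float = 0.2) -> str:
    """Decompose chunks into strips, drop low-scoring strips, and recompose the rest."""
    kept = [strip
            for chunk in chunks
            for strip in split_into_strips(chunk)
            if scorer(question, strip) >= threshold]
    return " ".join(kept)

chunk = ("HNSW builds a hierarchy of layers. Upper layers in HNSW are sparse. "
         "Our office moved last year.")
refined = refine_knowledge("How are layers built in HNSW?", [chunk])
print(refined)
```

Passing a different `scorer` (for example, a call to an LLM grader) swaps in a real evaluator without changing the decomposition logic.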

Knowledge Searching

Knowledge searching is applied when the retrieval is classified as either incorrect or ambiguous. When the retrieved chunks are found to be irrelevant, we discard them and use a web search API to find relevant results on the internet. So, instead of the incorrect chunks, we use sources from the internet for final answer generation.

In case of ambiguity, we apply both knowledge refinement and knowledge searching: the irrelevant strips are weeded out, and new information from the internet is added. The refined strips and web results are concatenated and sent to the LLM for answer generation.
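Putting the three actions together, the routing logic might look like the sketch below. `fake_refine` and `fake_search` are hypothetical offline placeholders passed in as functions; a real system would call the evaluator and a web search API (and, as in the implementation later in this article, rewrite the query before searching).

```python
from typing import Callable, List

def build_context(action: str, question: str, chunks: List[str],
                  refine: Callable[[str, List[str]], str],
                  search_web: Callable[[str], List[str]]) -> str:
    """Assemble the knowledge sent to the generator for a given CRAG action."""
    if action == "correct":
        parts = [refine(question, chunks)]                          # refined internal knowledge only
    elif action == "incorrect":
        parts = search_web(question)                                # discard chunks, use web results
    elif action == "ambiguous":
        parts = [refine(question, chunks)] + search_web(question)   # combine both sources
    else:
        raise ValueError(f"unknown action: {action}")
    return "\n\n".join(p for p in parts if p)

def fake_refine(question: str, chunks: List[str]) -> str:
    """Hypothetical refiner: keeps chunks that mention HNSW."""
    return " ".join(c for c in chunks if "HNSW" in c)

def fake_search(question: str) -> List[str]:
    """Hypothetical web search that returns a canned result."""
    return [f"[web] result for: {question}"]

chunks = ["HNSW uses layered graphs.", "Unrelated text."]
for action in ("correct", "incorrect", "ambiguous"):
    print(action, "->", repr(build_context(action, "Who wrote HNSW?", chunks, fake_refine, fake_search)))
```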

Combining an evaluator, knowledge refinement, and knowledge searching in this way can substantially improve the RAG performance of QA systems.

Now that we understand the concepts behind CRAG, let's implement them with LangGraph.

What is LangGraph?

LangGraph is an extension of the LangChain ecosystem. It lets us build AI apps, including agents and RAG pipelines, as graphs. It represents a workflow as a graph that can contain cycles, where each node is a function or a LangChain Runnable object, and edges are the connections between nodes. It also provides stateful execution: a global state object is shared among the nodes.

LangGraph’s main features include:

  • Nodes: Any function or LangChain Runnable object, such as a tool.
  • Edges: Define the direction of flow between nodes.
  • Stateful Graphs: The primary type of graph. It is designed to manage and update state objects as it processes data through its nodes.

LangGraph uses this state to support cyclic LLM calls with state persistence, which is crucial for agentic behavior. Its architecture draws inspiration from Pregel and Apache Beam.

We will use LangGraph to build our Corrective RAG pipeline.

How to Implement CRAG with LangGraph?

Let’s understand the structure of our pipeline. We will build a CRAG pipeline, but for brevity, we will use only two evaluator classes instead of three: a chunk is either relevant or irrelevant. As the evaluator, we will use a Mixtral 8x7B model served by Together AI. Alternatively, you could use a re-ranker such as Cohere Rerank as the evaluator: it returns documents sorted by decreasing relevance score, and thresholds on those scores can be used to assign each document to a category.

We will use the Tavily search API to run a web search when irrelevant chunks are found. Get API keys for both Together AI and Tavily before moving ahead. The same Mixtral model will also be used as the final LLM for answer generation. You can use other LLMs like Gemini, GPT models, Mistral Medium, etc.

This is our workflow.

[Figure: CRAG workflow built with LangGraph]

How to Set Up the Dev Environment?

Create a Python virtual environment and install the following libraries.

! pip install --quiet langchain_community langchain-openai langchainhub chromadb \
langchain langgraph tavily-python sentence-transformers beautifulsoup4

Now, set up API keys for Together and Tavily as environment variables.

import os

os.environ["TOGETHER_API_KEY"] = "Your Key"
os.environ["TAVILY_API_KEY"] = "Your Key"

Import the libraries.

from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai.chat_models import ChatOpenAI

How to Prepare Document?

In this step, we will use one of my blog posts as the document and load its text from the web page with LangChain's WebBaseLoader. We will use LangChain's recursive text splitter to split the document into chunks and index them in a Chroma database. We use the BAAI/bge-base-en-v1.5 model from the sentence-transformers library as the embedding model. You can use any other model you wish.

# Load
url = "https://www.analyticsvidhya.com/blog/2023/10/introduction-to-hnsw-hierarchical-navigable-small-world/"
loader = WebBaseLoader(url)
docs = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)

# Embed and index
embedding = SentenceTransformerEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Index
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-chroma",
    embedding=embedding,
)
retriever = vectorstore.as_retriever()
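To see what `chunk_size` and `chunk_overlap` mean, here is a simplified sliding-window chunker over a list of tokens, with 5 and 2 standing in for 500 and 100. This is an illustrative assumption, not how `RecursiveCharacterTextSplitter` works internally: that splitter first splits on separators such as paragraphs and sentences and then merges pieces up to the size limit, so its boundaries will differ. Only the idea that consecutive chunks share overlapping tokens carries over.

```python
from typing import List

def sliding_window_chunks(tokens: List[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Split tokens into windows of chunk_size, each sharing chunk_overlap tokens with the previous one."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
    return chunks

tokens = [f"t{i}" for i in range(12)]
for chunk in sliding_window_chunks(tokens, chunk_size=5, chunk_overlap=2):
    print(chunk)
```

The overlap means a sentence cut at a chunk boundary still appears intact in at least one chunk more often, at the cost of some duplicated text in the index.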

Define the LLM you will use. As discussed before, we will use Nous Research's fine-tuned version of Mixtral (Nous-Hermes-2-Mixtral-8x7B-DPO), served through Together AI.

TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
llm = ChatOpenAI(base_url="https://api.together.xyz/v1",
                 api_key=TOGETHER_API_KEY,
                 model = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")

Because the Together API is compatible with the OpenAI API, the only changes needed are the base URL, the API key, and the model name.

How to Define Nodes?

As mentioned earlier, LangGraph structures an application as a graph and lets us share data between nodes through a state object. So, let's define the state class.

from typing import Any, Dict, TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]

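GraphState is a TypedDict with a single attribute, “keys”: a dictionary that stores all the data downstream nodes need. Since each node returns a new “keys” dictionary that replaces the previous one, every node must pass along the values that later nodes depend on.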
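In LangGraph, a state key without a reducer is overwritten by whatever a node returns for it, while a key declared with a reducer (for example `Annotated[list, operator.add]`) is combined with its previous value. The pure-Python imitation below uses our own hypothetical helper `apply_update`, not LangGraph code, to show why a node that omits a field from “keys” drops it.

```python
import operator
from typing import Any, Callable, Dict, Optional

def apply_update(state: Dict[str, Any], update: Dict[str, Any],
                 reducers: Optional[Dict[str, Callable[[Any, Any], Any]]] = None) -> Dict[str, Any]:
    """Merge a node's update into the state: overwrite by default, or combine with a reducer."""
    reducers = reducers or {}
    merged = dict(state)
    for key, value in update.items():
        if key in reducers and key in merged:
            merged[key] = reducers[key](merged[key], value)  # reducer: combine old and new
        else:
            merged[key] = value                              # default: last value wins
    return merged

state = {"keys": {"question": "q", "documents": ["d1"], "run_web_search": "Yes"}}
# A node like transform_query returns a fresh "keys" dict without run_web_search...
after = apply_update(state, {"keys": {"question": "q2", "documents": ["d1"]}})
print(after["keys"])  # {'question': 'q2', 'documents': ['d1']}
# ...whereas a reducer such as operator.add accumulates values instead.
print(apply_update({"docs": ["d1"]}, {"docs": ["web"]}, {"docs": operator.add}))  # {'docs': ['d1', 'web']}
```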

Retrieve

We will now create the first node of our graph. Recall that a node in LangGraph can be any function or Runnable. The first node of our pipeline is the retriever, which fetches documents from the vector store.

def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}

Grade Documents

The next node grades the documents. We will use the LLM defined earlier to grade each chunk as “yes” (relevant) or “no” (irrelevant). If any chunk is irrelevant, we set the state key “run_web_search” to “Yes”.

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """

    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = PromptTemplate(
        template="""You are a grader assessing the relevance of a retrieved 
        document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, 
        grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out 
        erroneous retrievals. \n
        Give a binary score of 'yes' or 'no' score to indicate whether the document 
        is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble 
        or explanation.
        """,
        input_variables=["question", "context"],
    )

    chain = prompt | llm | JsonOutputParser()

    # Score
    filtered_docs = []
    search = "No"  # Default does not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke(
            {
                "question": question,
                "context": d.page_content,
            }
        )
        # Normalize the grade so answers like "Yes" or " yes" are also accepted
        grade = str(score.get("score", "")).strip().lower()
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }

In the above code, the chain is defined with the LangChain Expression Language (LCEL): the formatted prompt is piped to the LLM, and the LLM's output is piped to a JSON output parser.
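The filtering logic of `grade_documents` can be exercised without calling a model. In the sketch below, the hypothetical `fake_grader` returns canned JSON strings in place of the LLM, and `json.loads` plays the role of `JsonOutputParser`. The grade is lower-cased so that a model answering “Yes” instead of “yes” is still handled.

```python
import json
from typing import Callable, List, Tuple

def filter_documents(question: str, docs: List[str],
                     grader: Callable[[str, str], str]) -> Tuple[List[str], str]:
    """Keep relevant docs; flag a web search if any doc is graded irrelevant."""
    filtered, run_web_search = [], "No"
    for doc in docs:
        verdict = json.loads(grader(question, doc))  # e.g. '{"score": "yes"}'
        if str(verdict.get("score", "")).strip().lower() == "yes":
            filtered.append(doc)
        else:
            run_web_search = "Yes"  # one irrelevant chunk is enough to trigger a search
    return filtered, run_web_search

def fake_grader(question: str, doc: str) -> str:
    """Hypothetical stand-in for the LLM grader: relevant if the doc mentions HNSW."""
    return json.dumps({"score": "Yes" if "HNSW" in doc else "no"})

docs = ["HNSW stacks proximity graphs in layers.", "Cookies are tasty."]
print(filter_documents("How does HNSW work?", docs, fake_grader))
```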

Query Rewriting

Before sending the query to the search API, we rewrite it to improve the chances of getting useful web search results.

def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template with format instructions and the query
    prompt = PromptTemplate(
        template="""You are generating questions that are well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent / meaning. \n
        Here is the initial question:
        \n ------- \n
        {question}
        \n ------- \n
        Provide an improved question without any preamble, only respond with the
        updated question: """,
        input_variables=["question"],
    )
    # Chain: prompt -> LLM -> plain string
    chain = prompt | llm | StrOutputParser()

    better_question = chain.invoke({"question": question})

    return {
        "keys": {"documents": documents, "question": better_question}
    }

Web Search

In this node, we define a function that uses the Tavily API to fetch the top search results for the re-phrased question. The results are concatenated into a single document, which is appended to the documents list before being sent to the generation node.

def web_search(state):
    """
    Web search based on the re-phrased question using Tavily API.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Web results appended to documents.
    """

    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    print(web_results)
    documents.append(web_results)

    return {"keys": {"documents": documents, "question": question}}

LLM Generation

In this node, the documents are sent to the LLM along with the query, and the output is added to the state dictionary.

def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run: format the documents as plain text before inserting them into the prompt
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }

We have defined all the nodes that we need. Now, we can define the workflow and add nodes to it.

import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("web_search", web_search)  # web search

How to Define Edges?

We are done with the nodes; now we need to define the edges. The edges define the direction of the workflow. In LangGraph, there are two types of edges.

  • Conditional: The next node is chosen at run time by a function that inspects the current state.
  • Non-conditional: These are regular edges that connect one node to another.

In our case, we need a conditional edge after the grading node. If all the documents are relevant, we run the generation node; otherwise, we run the query transformation node.

def decide_to_generate(state):
    """
    Determines whether to generate an answer or re-generate a question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """

    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # At least one document was graded irrelevant in grade_documents,
        # so re-write the query and supplement with a web search
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

Now connect the respective nodes and set the entry point, the node from which the workflow starts.

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)
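Before compiling, it can help to trace which path the graph takes. The dependency-free sketch below mimics the wiring above with a tiny hand-rolled runner: each node takes and returns a state dict, a plain edge maps a node to a fixed successor, and a routing function plays the role of `decide_to_generate`. This is not LangGraph's API; the stub nodes are hypothetical, skip the real retriever, LLM, and search calls, and use a flat state instead of the `keys` wrapper.

```python
from typing import Callable, Dict, List

State = Dict[str, object]
Node = Callable[[State], State]

def run_graph(entry: str, nodes: Dict[str, Node], edges: Dict[str, str],
              routers: Dict[str, Callable[[State], str]], state: State,
              end: str = "END") -> List[str]:
    """Run nodes from `entry` until `end`, returning the names of the visited nodes."""
    path, current = [], entry
    while current != end:
        state = nodes[current](state)          # each node returns the updated state
        path.append(current)
        if current in routers:
            current = routers[current](state)  # conditional edge: a function picks the next node
        else:
            current = edges[current]           # plain edge: fixed successor
    return path

def make_nodes(all_relevant: bool) -> Dict[str, Node]:
    """Stub nodes; `all_relevant` controls what the grading stub reports."""
    return {
        "retrieve": lambda s: {**s, "documents": ["chunk"]},
        "grade_documents": lambda s: {**s, "run_web_search": "No" if all_relevant else "Yes"},
        "transform_query": lambda s: {**s, "question": "rewritten: " + str(s["question"])},
        "web_search": lambda s: {**s, "documents": list(s["documents"]) + ["web result"]},
        "generate": lambda s: {**s, "generation": "answer"},
    }

# Same wiring as the LangGraph workflow above.
edges = {"retrieve": "grade_documents", "transform_query": "web_search",
         "web_search": "generate", "generate": "END"}
routers = {"grade_documents": lambda s: "transform_query" if s["run_web_search"] == "Yes" else "generate"}

print(run_graph("retrieve", make_nodes(True), edges, routers, {"question": "q"}))
print(run_graph("retrieve", make_nodes(False), edges, routers, {"question": "q"}))
```

When every chunk is relevant, the run goes straight from grading to generation; otherwise it detours through query rewriting and web search.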

How to Run the Workflow?

Finally, compile the workflow and run it by passing a query. 

# Compile
app = workflow.compile()

# Run
inputs = {
    "keys": {
        "question": "Who is the author of the HNSW paper?",
    }
}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
       
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])

The article does not mention the authors of the HNSW paper, so the retriever could not find any relevant text chunks in the vector store. A conventional RAG pipeline would have failed on this simple question. With CRAG, this was not a problem: the irrelevant chunks triggered a web search, which supplied the missing information.

Conclusion

CRAG addresses a key gap in conventional RAG by grading the retrieved documents and bringing in the internet as a third knowledge source when retrieval falls short. In this article, we explored how CRAG works and implemented a simplified version with LangGraph, using an LLM as a binary relevance grader and the Tavily API for web search.

Key Takeaways

  • The traditional RAG approach of retrieving documents and passing them straight to an LLM may not always work.
  • CRAG stands for Corrective Retrieval Augmented Generation.
  • It improves traditional RAG by adding an evaluator, knowledge refining, and knowledge search steps to the pipeline.
  • In CRAG, an evaluator (a fine-tuned T5 model in the paper, an LLM in our implementation) distills the relevant retrieved chunks; these chunks are then decomposed into smaller strips so that irrelevant knowledge strips can be weeded out.
  • A web search system is used to supplement retrieved documents if the chunks are not reliable.
  • Finally, the documents and/or web sources are sent to an LLM for answer generation.

Frequently Asked Questions

Q1. What is LangGraph?

A. LangGraph is an open-source library for building stateful, cyclic, multi-actor agent systems. It is built on top of the LangChain ecosystem.

Q2. What is RAG?

A. RAG stands for Retrieval Augmented Generation. In RAG, documents are split into chunks, embedded, and stored in a vector database. The chunk embeddings are then matched against the embedding of the user query, and the top-k retrieved chunks are sent to an LLM for answer generation.

Q3. What is the difference between CRAG and RAG?

A. Corrective RAG uses an evaluator to distill the relevant documents from all the retrieved documents and, if needed, uses external knowledge sources such as web search to supplement answer generation.

Q4. When to use LangGraph over LangChain?

A. LangGraph is preferred for building cyclic multi-actor agents, while LangChain is better at creating chains or directed acyclic systems.

Q5. What is a RAG pipeline?

A. A RAG pipeline ingests documents from external data stores, processes and indexes them in a knowledge base such as a vector database, retrieves the chunks relevant to a query, and passes them to an LLM to generate an answer.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.
