Graph-Enhanced Retrieval-Augmented Generation (GRAPH-RAG)
The GraphRAG project is designed to address the challenge of efficient document retrieval, processing and query answering by combining vector search, BM25 and knowledge graph traversal. The system allows for extracting contextually relevant information from large sets of documents and organizing it into a graph structure, enabling effective querying and response generation. By leveraging large language models (LLMs) and embedding models, GraphRAG offers a powerful solution for retrieving information, analyzing document content and generating accurate answers in real time.
Project Overview
The GraphRAG project combines vector search techniques with knowledge graph exploration to improve document search and question answering. The system splits large documents into chunks, creates embeddings for them and stores the embeddings in a FAISS vector store for fast retrieval. A knowledge graph then links document chunks based on content similarity, connecting them into a comprehensive network. A QueryEngine processes queries by traversing this graph to retrieve the most relevant information, and when the retrieved context falls short, a large language model (LLM) generates the missing answer. The Visualizer renders the graph traversal so users can better understand how the system reached its result. This context-driven approach to document understanding produces accurate answers and provides a robust method for textual information retrieval.
Prerequisites
- Python 3.8+: Required for running the project.
- Libraries: Install NetworkX, Matplotlib, faiss-cpu, spaCy, NLTK, LangChain (including langchain-openai and langchain-community) and pypdf for PDF loading.
- OpenAI API Key: Needed for embeddings and LLM interactions.
- Google Colab / Jupyter Notebook: Recommended for running the system.
- pip: For managing and installing dependencies.
- Basic Python Knowledge: Familiarity with Python libraries such as NumPy, scikit-learn and NetworkX.
Approach
GraphRAG retrieves documents and generates responses through three main stages: document processing, knowledge graph construction and query expansion. Documents are first split into small chunks with a text splitter and embedded using OpenAIEmbeddings; the embeddings are stored in a FAISS vector store for efficient similarity-based search. A knowledge graph is then constructed in which each node represents a document chunk and each edge reflects content similarity and shared concepts. The QueryEngine traverses this graph with a modified Dijkstra's algorithm, expanding the context around the query so that the most relevant information is visited; if the expanded context does not yield a complete answer, a final response is generated by an LLM. The Visualizer renders the graph traversal, keeping the procedure transparent and showing how the system arrived at its answer. Together these steps make querying comprehensive, context-rich and efficient, which suits applications that require deep document understanding.
Workflow and Methodology
Workflow of GraphRAG System
- Document Loading: Documents are loaded using PyPDFLoader.
- Document Processing: Documents are split into chunks embedded using OpenAIEmbeddings and stored in a FAISS vector store.
- Knowledge Graph Construction: A knowledge graph is built, connecting document chunks based on similarity and shared concepts.
- Query Handling: The QueryEngine retrieves relevant documents and traverses the knowledge graph to expand the context.
- Context Expansion: The engine explores the graph, updates the context and uses an LLM to generate an answer if needed.
- Visualization: The Visualizer shows the graph traversal, highlighting explored nodes and edges.
- Final Answer: The system provides the final answer to the query, generated either from the context or via the LLM.
Methodology
- Chunking and Embedding: The documents are chunked into smaller pieces and embedded into vectors, facilitating comparison and retrieval.
- Graph Construction: Edges in the knowledge graph are drawn between document chunks whose embeddings have high cosine similarity, and shared concepts strengthen those links (a short worked example of the edge weight follows this list).
- Query Expansion: The system expands the query context while traversing the graph and, thus, retrieves the most relevant information.
- LLM-based Answering: When the context is still insufficient, the system generates the answer with an LLM according to the expanded context.
- Graph Visualization: Matplotlib renders the traversal path, making it transparent how the system arrived at the final answer.
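To make the edge weighting concrete, here is a small worked example using the alpha = 0.7 and beta = 0.3 defaults that appear later in _calculate_edge_weight; the similarity score and concept counts are made up for illustration.
# Worked example of the edge-weight formula (illustrative values only)
similarity_score = 0.85        # cosine similarity between two chunk embeddings
shared_concepts = 2            # concepts the two chunks have in common
max_possible_shared = 5        # size of the smaller of the two concept lists
normalized_shared = shared_concepts / max_possible_shared   # 0.4
alpha, beta = 0.7, 0.3
edge_weight = alpha * similarity_score + beta * normalized_shared
print(edge_weight)             # 0.7 * 0.85 + 0.3 * 0.4 = 0.715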
Data Collection and Preparation
Data Collection
For this project, data is collected in the form of text documents, such as PDFs or other text-based formats. The documents can cover a variety of topics and in this case, a sample document named "Understanding_Climate_Change.pdf" is used. The goal is to extract meaningful information from these documents by splitting them into smaller, manageable chunks for further analysis. These documents are processed by the PyPDFLoader or similar document loaders to extract the content, which is then stored as a list of text documents for embedding and further processing.
Data Preparation Workflow
- Collect Documents: Gather text documents (e.g., PDFs).
- Load Text: Extract content using PyPDFLoader.
- Chunk Text: Split documents into smaller chunks with RecursiveCharacterTextSplitter.
- Generate Embeddings: Convert chunks to vector embeddings using OpenAIEmbeddings.
- Store Embeddings: Save embeddings in FAISS for fast retrieval.
- Extract Concepts: Use spaCy and LLMs for named entity and concept extraction (see the short sketch after this list).
- Build Knowledge Graph: Link document chunks based on similarity and shared concepts.
- Prepare Data: Organize the data for efficient querying and analysis.
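As referenced in the concept-extraction step above, spaCy's named entity recognizer identifies entities of selected types. Here is a minimal sketch with an illustrative sentence; the entity labels match the ones the KnowledgeGraph class filters on later.
# Minimal spaCy NER sketch (sample sentence is illustrative only)
import spacy
nlp = spacy.load("en_core_web_sm")
doc = nlp("The IPCC reports that the United States and China are the largest emitters of greenhouse gases.")
entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]
print(entities)  # e.g. ['IPCC', 'the United States', 'China']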
Code Explanation
Mounting Google Drive
This code mounts Google Drive to Colab, allowing access to files stored in Drive. The mounted directory is /content/drive, enabling seamless file handling.
from google.colab import drive
drive.mount('/content/drive')
Installing Necessary Libraries
This step installs the Python packages the project depends on: langchain-community and langchain-openai for working with language models, rank_bm25 for keyword-based retrieval, pymupdf and pypdf for PDF processing, pydantic for data validation and faiss-cpu for fast similarity search over large collections of embeddings.
!pip install -U langchain-community
!pip install langchain-openai
!pip install rank_bm25
!pip install pymupdf
!pip install pypdf
!pip install pydantic
!pip install faiss-cpu
Changing Directory in Google Colab
The command %cd /content/drive/MyDrive/New 90 Projects/generative_ai_project/GraphRAG changes the working directory to the specified folder in Google Colab.
%cd /content/drive/MyDrive/New 90 Projects/generative_ai_project/GraphRAG
Importing Libraries for the Project
This code imports the core libraries needed for the project: natural language processing tools (NLTK and spaCy), machine learning utilities (scikit-learn and NumPy) and graph processing via NetworkX. LangChain provides the document loading, embedding and retrieval components, while ThreadPoolExecutor enables concurrent concept extraction and python-dotenv handles environment variables.
# Import necessary libraries
import os
import sys
import heapq
import numpy as np
import spacy
import nltk
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import List, Tuple, Dict
from dotenv import load_dotenv
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from pydantic import BaseModel, Field
from sklearn.metrics.pairwise import cosine_similarity
# LangChain imports
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI
# NLTK imports
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)
# spaCy setup
from spacy.cli import download
from spacy.lang.en import English
# Set OpenAI API key (Replace YOUR_ACTUAL_API_KEY with your key)
%env OPENAI_API_KEY="ADD YOUR API KEY"
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
DocumentProcessor Class
The DocumentProcessor class serves as an interface for processing documents with embedding generation and document chunk similarity analysis.
__init__: Initialization creates a RecursiveCharacterTextSplitter for dividing documents into chunks and an OpenAIEmbeddings instance for embedding them.
process_documents: This method takes a list of documents, splits them into chunks and builds a FAISS vector store from the chunk embeddings. It returns both the document chunks and the vector store.
create_embeddings_batch: This method embeds a list of texts in batches using the OpenAI embedding model and returns a NumPy array of embeddings corresponding to the input texts.
compute_similarity_matrix: This method computes cosine similarity between the embedding vectors and returns a similarity matrix describing how closely the texts relate to one another.
# Define the DocumentProcessor class
class DocumentProcessor:
def __init__(self):
"""
Initializes the DocumentProcessor with a text splitter and OpenAI embeddings.
Attributes:
- text_splitter: An instance of RecursiveCharacterTextSplitter with specified chunk size and overlap.
- embeddings: An instance of OpenAIEmbeddings used for embedding documents.
"""
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
self.embeddings = OpenAIEmbeddings()
def process_documents(self, documents):
"""
Processes a list of documents by splitting them into smaller chunks and creating a vector store.
Args:
- documents (list of str): A list of documents to be processed.
Returns:
- tuple: A tuple containing:
- splits (list of str): The list of split document chunks.
- vector_store (FAISS): A FAISS vector store created from the split document chunks and their embeddings.
"""
splits = self.text_splitter.split_documents(documents)
vector_store = FAISS.from_documents(splits, self.embeddings)
return splits, vector_store
def create_embeddings_batch(self, texts, batch_size=32):
"""
Creates embeddings for a list of texts in batches.
Args:
- texts (list of str): A list of texts to be embedded.
- batch_size (int, optional): The number of texts to process in each batch. Default is 32.
Returns:
- numpy.ndarray: An array of embeddings for the input texts.
"""
embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
batch_embeddings = self.embeddings.embed_documents(batch)
embeddings.extend(batch_embeddings)
return np.array(embeddings)
def compute_similarity_matrix(self, embeddings):
"""
Computes a cosine similarity matrix for a given set of embeddings.
Args:
- embeddings (numpy.ndarray): An array of embeddings.
Returns:
- numpy.ndarray: A cosine similarity matrix for the input embeddings.
"""
return cosine_similarity(embeddings)
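The two batch helpers above are not exercised elsewhere in this walkthrough, so here is a short hypothetical usage sketch; the sample texts are invented for illustration and an OpenAI API key must already be set.
# Hypothetical usage of create_embeddings_batch and compute_similarity_matrix
processor = DocumentProcessor()
texts = [
    "Greenhouse gases trap heat in the atmosphere.",
    "Burning fossil fuels releases carbon dioxide.",
    "Deforestation reduces the planet's capacity to absorb CO2.",
]
embeddings = processor.create_embeddings_batch(texts, batch_size=2)
similarity_matrix = processor.compute_similarity_matrix(embeddings)
print(similarity_matrix.shape)  # (3, 3); each entry is the cosine similarity between two texts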
KnowledgeGraph Class: Building and Analyzing a Knowledge Graph
The KnowledgeGraph class creates and manages a graph of concepts derived from document content. It adds nodes for document splits, computes embeddings and extracts concepts using both spaCy and a large language model (LLM). The class also establishes connections (edges) between nodes based on the similarity of embeddings and shared concepts, with edge weights determined by both factors. This class helps structure and analyze information from a knowledge base, enabling efficient querying and exploration.
# Define the Concepts class
class Concepts(BaseModel):
concepts_list: List[str] = Field(description="List of concepts")
# Define the KnowledgeGraph class
class KnowledgeGraph:
def __init__(self):
"""
Initializes the KnowledgeGraph with a graph, lemmatizer and NLP model.
Attributes:
- graph: An instance of a networkx Graph.
- lemmatizer: An instance of WordNetLemmatizer.
- concept_cache: A dictionary to cache extracted concepts.
- nlp: An instance of a spaCy NLP model.
- edges_threshold: A float value that sets the threshold for adding edges based on similarity.
"""
self.graph = nx.Graph()
self.lemmatizer = WordNetLemmatizer()
self.concept_cache = {}
self.nlp = self._load_spacy_model()
self.edges_threshold = 0.8
def build_graph(self, splits, llm, embedding_model):
"""
Builds the knowledge graph by adding nodes, creating embeddings, extracting concepts and adding edges.
Args:
- splits (list): A list of document splits.
- llm: An instance of a large language model.
- embedding_model: An instance of an embedding model.
Returns:
- None
"""
self._add_nodes(splits)
embeddings = self._create_embeddings(splits, embedding_model)
self._extract_concepts(splits, llm)
self._add_edges(embeddings)
def _add_nodes(self, splits):
"""
Adds nodes to the graph from the document splits.
Args:
- splits (list): A list of document splits.
Returns:
- None
"""
for i, split in enumerate(splits):
self.graph.add_node(i, content=split.page_content)
def _create_embeddings(self, splits, embedding_model):
"""
Creates embeddings for the document splits using the embedding model.
Args:
- splits (list): A list of document splits.
- embedding_model: An instance of an embedding model.
Returns:
- numpy.ndarray: An array of embeddings for the document splits.
"""
texts = [split.page_content for split in splits]
return embedding_model.embed_documents(texts)
def _compute_similarities(self, embeddings):
"""
Computes the cosine similarity matrix for the embeddings.
Args:
- embeddings (numpy.ndarray): An array of embeddings.
Returns:
- numpy.ndarray: A cosine similarity matrix for the embeddings.
"""
return cosine_similarity(embeddings)
def _load_spacy_model(self):
"""
Loads the spaCy NLP model, downloading it if necessary.
Args:
- None
Returns:
- spacy.Language: An instance of a spaCy NLP model.
"""
try:
return spacy.load("en_core_web_sm")
except OSError:
print("Downloading spaCy model...")
download("en_core_web_sm")
return spacy.load("en_core_web_sm")
def _extract_concepts_and_entities(self, content, llm):
"""
Extracts concepts and named entities from the content using spaCy and a large language model.
Args:
- content (str): The content from which to extract concepts and entities.
- llm: An instance of a large language model.
Returns:
- list: A list of extracted concepts and entities.
"""
if content in self.concept_cache:
return self.concept_cache[content]
# Extract named entities using spaCy
doc = self.nlp(content)
named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]
# Extract general concepts using LLM
concept_extraction_prompt = PromptTemplate(
input_variables=["text"],
template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
)
concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
general_concepts = concept_chain.invoke({"text": content}).concepts_list
# Combine named entities and general concepts
all_concepts = list(set(named_entities + general_concepts))
self.concept_cache[content] = all_concepts
return all_concepts
def _extract_concepts(self, splits, llm):
"""
Extracts concepts for all document splits using multi-threading.
Args:
- splits (list): A list of document splits.
- llm: An instance of a large language model.
Returns:
- None
"""
with ThreadPoolExecutor() as executor:
future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i
for i, split in enumerate(splits)}
for future in tqdm(as_completed(future_to_node), total=len(splits), desc="Extracting concepts and entities"):
node = future_to_node[future]
concepts = future.result()
self.graph.nodes[node]['concepts'] = concepts
def _add_edges(self, embeddings):
"""
Adds edges to the graph based on the similarity of embeddings and shared concepts.
Args:
- embeddings (numpy.ndarray): An array of embeddings for the document splits.
Returns:
- None
"""
similarity_matrix = self._compute_similarities(embeddings)
num_nodes = len(self.graph.nodes)
for node1 in tqdm(range(num_nodes), desc="Adding edges"):
for node2 in range(node1 + 1, num_nodes):
similarity_score = similarity_matrix[node1][node2]
if similarity_score > self.edges_threshold:
shared_concepts = set(self.graph.nodes[node1]['concepts']) & set(self.graph.nodes[node2]['concepts'])
edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
self.graph.add_edge(node1, node2, weight=edge_weight,
similarity=similarity_score,
shared_concepts=list(shared_concepts))
def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
"""
Calculates the weight of an edge based on similarity score and shared concepts.
Args:
- node1 (int): The first node.
- node2 (int): The second node.
- similarity_score (float): The similarity score between the nodes.
- shared_concepts (set): The set of shared concepts between the nodes.
- alpha (float, optional): The weight of the similarity score. Default is 0.7.
- beta (float, optional): The weight of the shared concepts. Default is 0.3.
Returns:
- float: The calculated weight of the edge.
"""
max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
return alpha * similarity_score + beta * normalized_shared_concepts
def _lemmatize_concept(self, concept):
"""
Lemmatizes a given concept.
Args:
- concept (str): The concept to be lemmatized.
Returns:
- str: The lemmatized concept.
"""
return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])
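Once build_graph has been called (which happens later via GraphRAG.process_documents), the underlying networkx graph can be inspected directly. A small hedged sketch, assuming splits, llm and embedding_model already exist:
# Hypothetical inspection of a built knowledge graph
kg = KnowledgeGraph()
# kg.build_graph(splits, llm, embedding_model)  # assumes splits, llm and embedding_model are defined
print(f"Nodes: {kg.graph.number_of_nodes()}, Edges: {kg.graph.number_of_edges()}")
for u, v, data in list(kg.graph.edges(data=True))[:3]:
    print(u, v, round(data['weight'], 3), data['shared_concepts'])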
QueryEngine Class: Handling Queries and Expanding Context
__init__: Initializes the QueryEngine with a vector store, knowledge graph, LLM and a chain to check answers.
_create_answer_check_chain: Creates a chain that checks if the context provides a complete answer to the query.
_check_answer: Checks if the given context fully answers the query, returning a boolean and the answer.
_expand_context: Expands the context by traversing the knowledge graph and updating it with the most relevant nodes.
query: Processes the query by retrieving relevant documents, expanding the context and generating a final answer.
_retrieve_relevant_documents: Retrieves documents relevant to the query using a vector store and a contextual compression retriever.
# Define the AnswerCheck class
class AnswerCheck(BaseModel):
is_complete: bool = Field(description="Whether the current context provides a complete answer to the query")
answer: str = Field(description="The current answer based on the context, if any")
# Define the QueryEngine class
class QueryEngine:
def __init__(self, vector_store, knowledge_graph, llm):
self.vector_store = vector_store
self.knowledge_graph = knowledge_graph
self.llm = llm
self.max_context_length = 4000
self.answer_check_chain = self._create_answer_check_chain()
def _create_answer_check_chain(self):
"""
Creates a chain to check if the context provides a complete answer to the query.
Args:
- None
Returns:
- Chain: A chain to check if the context provides a complete answer.
"""
answer_check_prompt = PromptTemplate(
input_variables=["query", "context"],
template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
)
return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)
def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
"""
Checks if the current context provides a complete answer to the query.
Args:
- query (str): The query to be answered.
- context (str): The current context.
Returns:
- tuple: A tuple containing:
- is_complete (bool): Whether the context provides a complete answer.
- answer (str): The answer based on the context, if complete.
"""
response = self.answer_check_chain.invoke({"query": query, "context": context})
return response.is_complete, response.answer
def _expand_context(self, query: str, relevant_docs) -> Tuple[str, List[int], Dict[int, str], str]:
"""
Expands the context by traversing the knowledge graph using a Dijkstra-like approach.
This method implements a modified version of Dijkstra's algorithm to explore the knowledge graph,
prioritizing the most relevant and strongly connected information. The algorithm works as follows:
1. Initialize:
- Start with nodes corresponding to the most relevant documents.
- Use a priority queue to manage the traversal order, where priority is based on connection strength.
- Maintain a dictionary of best known "distances" (inverse of connection strengths) to each node.
2. Traverse:
- Always explore the node with the highest priority (strongest connection) next.
- For each node, check if we've found a complete answer.
- Explore the node's neighbors, updating their priorities if a stronger connection is found.
3. Concept Handling:
- Track visited concepts to guide the exploration towards new, relevant information.
- Expand to neighbors only if they introduce new concepts.
4. Termination:
- Stop if a complete answer is found.
- Continue until the priority queue is empty (all reachable nodes explored).
This approach ensures that:
- We prioritize the most relevant and strongly connected information.
- We explore new concepts systematically.
- We find the most relevant answer by following the strongest connections in the knowledge graph.
Args:
- query (str): The query to be answered.
- relevant_docs (List[Document]): A list of relevant documents to start the traversal.
Returns:
- tuple: A tuple containing:
- expanded_context (str): The accumulated context from traversed nodes.
- traversal_path (List[int]): The sequence of node indices visited.
- filtered_content (Dict[int, str]): A mapping of node indices to their content.
- final_answer (str): The final answer found, if any.
"""
# Initialize variables
expanded_context = ""
traversal_path = []
visited_concepts = set()
filtered_content = {}
final_answer = ""
priority_queue = []
distances = {} # Stores the best known "distance" (inverse of connection strength) to each node
print("\nTraversing the knowledge graph:")
# Initialize priority queue with closest nodes from relevant docs
for doc in relevant_docs:
# Find the most similar node in the knowledge graph for each relevant document
closest_nodes = self.vector_store.similarity_search_with_score(doc.page_content, k=1)
closest_node_content, similarity_score = closest_nodes[0]
# Get the corresponding node in our knowledge graph
closest_node = next(n for n in self.knowledge_graph.graph.nodes if self.knowledge_graph.graph.nodes[n]['content'] == closest_node_content.page_content)
# Initialize priority (inverse of similarity score for min-heap behavior)
priority = 1 / similarity_score
heapq.heappush(priority_queue, (priority, closest_node))
distances[closest_node] = priority
step = 0
while priority_queue:
# Get the node with the highest priority (lowest distance value)
current_priority, current_node = heapq.heappop(priority_queue)
# Skip if we've already found a better path to this node
if current_priority > distances.get(current_node, float('inf')):
continue
if current_node not in traversal_path:
step += 1
traversal_path.append(current_node)
node_content = self.knowledge_graph.graph.nodes[current_node]['content']
node_concepts = self.knowledge_graph.graph.nodes[current_node]['concepts']
# Add node content to our accumulated context
filtered_content[current_node] = node_content
expanded_context += "\n" + node_content if expanded_context else node_content
# Log the current step for debugging and visualization
print(f"\nStep {step} - Node {current_node}:")
print(f"Content: {node_content[:100]}...")
print(f"Concepts: {', '.join(node_concepts)}")
print("-" * 50)
# Check if we have a complete answer with the current context
is_complete, answer = self._check_answer(query, expanded_context)
if is_complete:
final_answer = answer
break
# Process the concepts of the current node
node_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in node_concepts)
if not node_concepts_set.issubset(visited_concepts):
visited_concepts.update(node_concepts_set)
# Explore neighbors
for neighbor in self.knowledge_graph.graph.neighbors(current_node):
edge_data = self.knowledge_graph.graph[current_node][neighbor]
edge_weight = edge_data['weight']
# Calculate new distance (priority) to the neighbor
# Note: We use 1 / edge_weight because higher weights mean stronger connections
distance = current_priority + (1 / edge_weight)
# If we've found a stronger connection to the neighbor, update its distance
if distance < distances.get(neighbor, float('inf')):
distances[neighbor] = distance
heapq.heappush(priority_queue, (distance, neighbor))
# Process the neighbor node if it's not already in our traversal path
if neighbor not in traversal_path:
step += 1
traversal_path.append(neighbor)
neighbor_content = self.knowledge_graph.graph.nodes[neighbor]['content']
neighbor_concepts = self.knowledge_graph.graph.nodes[neighbor]['concepts']
filtered_content[neighbor] = neighbor_content
expanded_context += "\n" + neighbor_content if expanded_context else neighbor_content
# Log the neighbor node information
print(f"\nStep {step} - Node {neighbor} (neighbor of {current_node}):")
print(f"Content: {neighbor_content[:100]}...")
print(f"Concepts: {', '.join(neighbor_concepts)}")
print("-" * 50)
# Check if we have a complete answer after adding the neighbor's content
is_complete, answer = self._check_answer(query, expanded_context)
if is_complete:
final_answer = answer
break
# Process the neighbor's concepts
neighbor_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in neighbor_concepts)
if not neighbor_concepts_set.issubset(visited_concepts):
visited_concepts.update(neighbor_concepts_set)
# If we found a final answer, break out of the main loop
if final_answer:
break
# If we haven't found a complete answer, generate one using the LLM
if not final_answer:
print("\nGenerating final answer...")
response_prompt = PromptTemplate(
input_variables=["query", "context"],
template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
)
response_chain = response_prompt | self.llm
input_data = {"query": query, "context": expanded_context}
final_answer = response_chain.invoke(input_data)
return expanded_context, traversal_path, filtered_content, final_answer
def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
"""
Processes a query by retrieving relevant documents, expanding the context and generating the final answer.
Args:
- query (str): The query to be answered.
Returns:
- tuple: A tuple containing:
- final_answer (str): The final answer to the query.
- traversal_path (list): The traversal path of nodes in the knowledge graph.
- filtered_content (dict): The filtered content of nodes.
"""
with get_openai_callback() as cb:
print(f"\nProcessing query: {query}")
relevant_docs = self._retrieve_relevant_documents(query)
expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(query, relevant_docs)
if not final_answer:
print("\nGenerating final answer...")
response_prompt = PromptTemplate(
input_variables=["query", "context"],
template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
)
response_chain = response_prompt | self.llm
input_data = {"query": query, "context": expanded_context}
response = response_chain.invoke(input_data)
final_answer = response
else:
print("\nComplete answer found during traversal.")
print(f"\nFinal Answer: {final_answer}")
print(f"\nTotal Tokens: {cb.total_tokens}")
print(f"Prompt Tokens: {cb.prompt_tokens}")
print(f"Completion Tokens: {cb.completion_tokens}")
print(f"Total Cost (USD): ${cb.total_cost}")
return final_answer, traversal_path, filtered_content
def _retrieve_relevant_documents(self, query: str):
"""
Retrieves relevant documents based on the query using the vector store.
Args:
- query (str): The query to be answered.
Returns:
- list: A list of relevant documents.
"""
print("\nRetrieving relevant documents...")
retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
compressor = LLMChainExtractor.from_llm(self.llm)
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
return compression_retriever.invoke(query)
Visualizer Class: Graph Traversal Visualization
The Visualizer class provides methods for visualizing traversal of the knowledge graph. The visualize_traversal method highlights the traversal path with its nodes and edges, marks the start and end nodes with distinct colors, adds node labels and draws a colorbar that maps edge colors to edge weights. The companion print_filtered_content method prints the filtered content of each node along the traversal path, showing the first 200 characters per node.
# Import necessary libraries
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Define the Visualizer class
class Visualizer:
@staticmethod
def visualize_traversal(graph, traversal_path):
"""
Visualizes the traversal path on the knowledge graph with nodes, edges and traversal path highlighted.
Args:
- graph (networkx.Graph): The knowledge graph containing nodes and edges.
- traversal_path (list of int): The list of node indices representing the traversal path.
Returns:
- None
"""
traversal_graph = nx.DiGraph()
# Add nodes and edges from the original graph
for node in graph.nodes():
traversal_graph.add_node(node)
for u, v, data in graph.edges(data=True):
traversal_graph.add_edge(u, v, **data)
fig, ax = plt.subplots(figsize=(16, 12))
# Generate positions for all nodes
pos = nx.spring_layout(traversal_graph, k=1, iterations=50)
# Draw regular edges with color based on weight
edges = traversal_graph.edges()
edge_weights = [traversal_graph[u][v].get('weight', 0.5) for u, v in edges]
nx.draw_networkx_edges(traversal_graph, pos,
edgelist=edges,
edge_color=edge_weights,
edge_cmap=plt.cm.Blues,
width=2,
ax=ax)
# Draw nodes
nx.draw_networkx_nodes(traversal_graph, pos,
node_color='lightblue',
node_size=3000,
ax=ax)
# Draw traversal path with curved arrows
edge_offset = 0.1
for i in range(len(traversal_path) - 1):
start = traversal_path[i]
end = traversal_path[i + 1]
start_pos = pos[start]
end_pos = pos[end]
# Calculate control point for curve
mid_point = ((start_pos[0] + end_pos[0]) / 2, (start_pos[1] + end_pos[1]) / 2)
control_point = (mid_point[0] + edge_offset, mid_point[1] + edge_offset)
# Draw curved arrow
arrow = patches.FancyArrowPatch(start_pos, end_pos,
connectionstyle=f"arc3,rad={0.3}",
color='red',
arrowstyle="->",
mutation_scale=20,
linestyle='--',
linewidth=2,
zorder=4)
ax.add_patch(arrow)
# Prepare labels for the nodes
labels = {}
for i, node in enumerate(traversal_path):
concepts = graph.nodes[node].get('concepts', [])
label = f"{i + 1}. {concepts[0] if concepts else ''}"
labels[node] = label
for node in traversal_graph.nodes():
if node not in labels:
concepts = graph.nodes[node].get('concepts', [])
labels[node] = concepts[0] if concepts else ''
# Draw labels
nx.draw_networkx_labels(traversal_graph, pos, labels, font_size=8, font_weight="bold", ax=ax)
# Highlight start and end nodes
start_node = traversal_path[0]
end_node = traversal_path[-1]
nx.draw_networkx_nodes(traversal_graph, pos,
nodelist=[start_node],
node_color='lightgreen',
node_size=3000,
ax=ax)
nx.draw_networkx_nodes(traversal_graph, pos,
nodelist=[end_node],
node_color='lightcoral',
node_size=3000,
ax=ax)
ax.set_title("Graph Traversal Flow")
ax.axis('off')
# Add colorbar for edge weights
sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues, norm=plt.Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label('Edge Weight', rotation=270, labelpad=15)
# Add legend
regular_line = plt.Line2D([0], [0], color='blue', linewidth=2, label='Regular Edge')
traversal_line = plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Traversal Path')
start_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15, label='Start Node')
end_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral', markersize=15, label='End Node')
legend = plt.legend(handles=[regular_line, traversal_line, start_point, end_point], loc='upper left', bbox_to_anchor=(0, 1), ncol=2)
legend.get_frame().set_alpha(0.8)
plt.tight_layout()
plt.show()
@staticmethod
def print_filtered_content(traversal_path, filtered_content):
"""
Prints the filtered content of visited nodes in the order of traversal.
Args:
- traversal_path (list of int): The list of node indices representing the traversal path.
- filtered_content (dict of int: str): A dictionary mapping node indices to their filtered content.
Returns:
- None
"""
print("\nFiltered content of visited nodes in order of traversal:")
for i, node in enumerate(traversal_path):
print(f"\nStep {i + 1} - Node {node}:")
print(f"Filtered Content: {filtered_content.get(node, 'No filtered content available')[:200]}...") # Print first 200 characters
print("-" * 50)
GraphRAG Class
The GraphRAG class manages the entire workflow for document processing, knowledge graph construction, querying and visualization.
- __init__: Initializes the GraphRAG system with necessary components like a large language model (LLM), embedding model, document processor, knowledge graph, query engine (set to None initially) and a visualizer.
- process_documents: Processes a list of documents by splitting them into chunks, embedding the chunks and building a knowledge graph. It then initializes the query engine for handling future queries.
- query: Handles queries by retrieving relevant information from the knowledge graph using the query engine. It also visualizes the traversal path through the knowledge graph and returns the final response to the query.
class GraphRAG:
def __init__(self):
"""
Initializes the GraphRAG system with components for document processing, knowledge graph construction,
querying and visualization.
Attributes:
- llm: An instance of a large language model (LLM) for generating responses.
- embedding_model: An instance of an embedding model for document embeddings.
- document_processor: An instance of the DocumentProcessor class for processing documents.
- knowledge_graph: An instance of the KnowledgeGraph class for building and managing the knowledge graph.
- query_engine: An instance of the QueryEngine class for handling queries (initialized as None).
- visualizer: An instance of the Visualizer class for visualizing the knowledge graph traversal.
"""
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
self.embedding_model = OpenAIEmbeddings()
self.document_processor = DocumentProcessor()
self.knowledge_graph = KnowledgeGraph()
self.query_engine = None
self.visualizer = Visualizer()
def process_documents(self, documents):
"""
Processes a list of documents by splitting them into chunks, embedding them and building a knowledge graph.
Args:
- documents (list of str): A list of documents to be processed.
Returns:
- None
"""
splits, vector_store = self.document_processor.process_documents(documents)
self.knowledge_graph.build_graph(splits, self.llm, self.embedding_model)
self.query_engine = QueryEngine(vector_store, self.knowledge_graph, self.llm)
def query(self, query: str):
"""
Handles a query by retrieving relevant information from the knowledge graph and visualizing the traversal path.
Args:
- query (str): The query to be answered.
Returns:
- str: The response to the query.
"""
response, traversal_path, filtered_content = self.query_engine.query(query)
if traversal_path:
self.visualizer.visualize_traversal(self.knowledge_graph.graph, traversal_path)
else:
print("No traversal path to visualize.")
return response
Defining the Document Location
The variable path stores the file path to a PDF document named "Understanding_Climate_Change.pdf" located in the specified directory on Google Drive.
path = "/content/drive/MyDrive/New 90 Projects/generative_ai_project/Fusion Retrieval: Combining Vector Search and BM25 for Enhanced Document Retrieval/Understanding_Climate_Change.pdf"
Loading and Selecting Documents from PDF
The code loads the "Understanding_Climate_Change.pdf" file using the PyPDFLoader class, which reads the PDF and returns one document per page. The first 10 pages are then selected for further processing and analysis.
loader = PyPDFLoader(path)
documents = loader.load()
documents = documents[:10]
Initializing the GraphRAG System
The graph_rag object is created as an instance of the GraphRAG class. This initializes all the components of the GraphRAG system, including the large language model (LLM), embedding model, document processor, knowledge graph, query engine and visualizer. The system is now ready to process documents and handle queries.
graph_rag = GraphRAG()
Processing Documents with GraphRAG
The process_documents method is called on the graph_rag object, passing the first 10 documents (documents) as input. This method splits the documents into chunks, embeds them and builds the knowledge graph using the DocumentProcessor and KnowledgeGraph classes. It also initializes the QueryEngine with the vector store and knowledge graph, preparing the system for querying.
graph_rag.process_documents(documents)
Querying the GraphRAG System
The query method is called on the graph_rag object with the query "what is the main cause of climate change?". The method retrieves relevant information from the knowledge graph, visualizes the traversal path and generates a response based on the available data. The final answer is stored in the response variable.
query = "what is the main cause of climate change?"
response = graph_rag.query(query)
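GraphRAG.query returns only the answer. If the traversal path and per-node content are also needed (for example, to pass to Visualizer.print_filtered_content), the underlying query engine can be called directly; a small sketch, assuming process_documents has already run:
# Optional: inspect traversal details via the underlying QueryEngine
# Note: this re-runs the query and incurs additional API calls
answer, traversal_path, filtered_content = graph_rag.query_engine.query(query)
Visualizer.print_filtered_content(traversal_path, filtered_content)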
Conclusion
The GraphRAG project successfully integrates advanced techniques in document retrieval, knowledge graph construction and query expansion to provide an efficient and accurate system for answering complex queries. By combining vector search, BM25 and knowledge graph traversal, it ensures contextually relevant and precise responses. The system's ability to process, embed and organize document chunks into a knowledge graph enhances its retrieval capabilities. Additionally, the use of large language models ensures that incomplete answers are supplemented, while visualization tools provide transparency in the query resolution process. Overall, GraphRAG offers a robust solution for applications requiring deep document understanding, making it an effective tool for intelligent information retrieval.
Challenges New Coders Might Face
Challenge: Handling Large PDF Files
Solution: Use batch processing, split documents into smaller chunks and utilize Colab's high-RAM runtime for better performance.
Challenge: Slow Embedding Generation
Solution: Batch processing of text embeddings in smaller chunks (using a batch_size) helps manage resources efficiently. Additionally, embeddings can be precomputed and stored in FAISS for faster access during query handling.
Challenge: Knowledge Graph Construction
Solution: Focus on cosine similarity for edge creation, which is computationally efficient and easy to implement. Using spaCy for named entity recognition and leveraging LLMs for concept extraction allows for better and more automated graph enrichment.
Challenge: Query Mismatch in Retrieval
Solution: Use LLM-generated hypothetical documents to refine search queries and improve retrieval accuracy (see the sketch after this list).
Challenge: Dependency Installation Issues
Solution: Ensure Python 3.8+, install dependencies with !pip install --upgrade and use virtual environments for package management.
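The "query mismatch" fix above refers to generating a hypothetical document from the query (a HyDE-style technique), which is not part of the GraphRAG code in this walkthrough. A minimal hedged sketch of the idea, reusing the ChatOpenAI model and a FAISS vector_store like the one built by DocumentProcessor:
# Hedged sketch of HyDE-style query refinement (not part of the code above)
hyde_prompt = PromptTemplate(
    input_variables=["query"],
    template="Write a short passage that would answer the question:\n\n{query}\n\nPassage:"
)
hyde_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini")
hypothetical_doc = (hyde_prompt | hyde_llm).invoke({"query": "what is the main cause of climate change?"}).content
# Retrieve with the hypothetical passage instead of the raw query
# docs = vector_store.similarity_search(hypothetical_doc, k=5)  # vector_store from DocumentProcessor.process_documents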
FAQ
Question 1: What is GraphRAG?
Answer: GraphRAG is an advanced retrieval-augmented generation (RAG) system that integrates knowledge graphs with large language models (LLMs) to enhance document retrieval and generation. By leveraging the relational structure of graphs, GraphRAG improves the accuracy and contextual relevance of AI-generated insights.
Question 2: Is GraphRAG suitable for real-time applications?
Answer: GraphRAG can be adapted for real-time applications, though it may require optimization to meet the performance demands of such environments. The system's modular design allows for adjustments to balance accuracy and response time, making it feasible for real-time use cases.
Question 3: How does GraphRAG improve document retrieval?
Answer: GraphRAG enhances document retrieval by utilizing the interconnectedness of data within knowledge graphs. This approach allows for more precise and contextually appropriate responses to complex queries, addressing the limitations of traditional vector-based retrieval methods.
Question 4. How does GraphRAG handle user queries?
Answer: GraphRAG interprets user queries, retrieves relevant information from the knowledge graph and generates comprehensive responses using a large language model.
Question 5. Where can I learn more about GraphRAG?
Answer: For more detailed information and resources on GraphRAG, you can visit Microsoft's official GraphRAG project on GitHub.
Additionally, the Neo4j blog offers an in-depth guide on GraphRAG patterns and implementations.
These resources provide comprehensive insights into the concepts, applications and advancements related to GraphRAG.