"""Financial Policy Document Chatbot.

Gradio app: upload a financial-policy PDF, index it into FAISS, and answer
questions about it (with page references) via a Together-hosted Llama model
wrapped as a LangChain LLM.
"""

import io
import traceback
from typing import List, Optional

import gradio as gr
import PyPDF2
from together import Together
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms.base import LLM


# ---------------------------
# WRAP TOGETHER API AS LLM
# ---------------------------
class TogetherLLM(LLM):
    """LangChain-compatible LLM wrapper around the Together chat-completions API."""

    # Declared as (pydantic) class fields so the LangChain LLM base accepts them.
    client: Together = None
    model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    temperature: float = 0.3
    max_tokens: int = 1000

    def __init__(self, client, model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                 temperature=0.3, max_tokens=1000, **kwargs):
        super().__init__(**kwargs)
        # object.__setattr__ bypasses pydantic validation/immutability so the
        # non-serializable Together client can be attached to the model.
        object.__setattr__(self, 'client', client)
        object.__setattr__(self, 'model', model)
        object.__setattr__(self, 'temperature', temperature)
        object.__setattr__(self, 'max_tokens', max_tokens)

    @property
    def _llm_type(self) -> str:
        return "together-llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send a single-turn prompt to Together and return the reply text.

        API failures are returned as an error string (not raised) so the
        retrieval chain keeps running.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    class Config:
        # Allow the non-pydantic Together client as a field type.
        arbitrary_types_allowed = True


# ---------------------------
# PDF TEXT EXTRACTION
# ---------------------------
def extract_text_from_pdf(pdf_file):
    """Extract text from PDF with page references.

    Accepts a Gradio file object (has ``.name``), a file-like object (has
    ``.read``), or raw bytes. Returns one ``Document`` per page; pages that
    fail or contain no text get a placeholder document so page numbering
    stays complete. On a total failure, returns a single error document with
    ``page == -1``.
    """
    docs = []
    try:
        print("Starting PDF extraction...")

        # Handle different input types.
        if hasattr(pdf_file, 'name'):
            # File uploaded through Gradio: read it back from disk.
            with open(pdf_file.name, 'rb') as file:
                pdf_content = file.read()
        elif hasattr(pdf_file, "read"):
            pdf_content = pdf_file.read()
            if hasattr(pdf_file, "seek"):
                # Rewind so the caller can reuse the stream.
                pdf_file.seek(0)
        else:
            pdf_content = pdf_file

        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
        print(f"PDF has {len(pdf_reader.pages)} pages")

        for page_num, page in enumerate(pdf_reader.pages, start=1):
            try:
                page_text = page.extract_text()
                if page_text and page_text.strip():
                    docs.append(Document(
                        page_content=page_text.strip(),
                        metadata={"page": page_num, "source": "financial_policy"}
                    ))
                    print(f"Extracted text from page {page_num}: {len(page_text)} characters")
                else:
                    # Keep a placeholder so every physical page is represented.
                    docs.append(Document(
                        page_content="[No extractable text found on this page]",
                        metadata={"page": page_num, "source": "financial_policy"}
                    ))
            except Exception as e:
                print(f"Error extracting page {page_num}: {str(e)}")
                docs.append(Document(
                    page_content=f"[Error extracting page {page_num}: {str(e)}]",
                    metadata={"page": page_num, "source": "financial_policy"}
                ))

        print(f"Total documents extracted: {len(docs)}")
        return docs

    except Exception as e:
        print(f"Error in PDF extraction: {str(e)}")
        traceback.print_exc()
        return [Document(page_content=f"Error extracting text: {str(e)}",
                         metadata={"page": -1})]


# ---------------------------
# BUILD KNOWLEDGE BASE (FAISS)
# ---------------------------
def build_vector_db(docs):
    """Convert extracted documents into FAISS vector DB.

    Splits page documents into ~1000-character overlapping chunks, embeds
    them with a local MiniLM model (CPU), and indexes them in FAISS.
    Returns the FAISS store, or ``None`` on failure.
    """
    try:
        print("Building vector database...")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        split_docs = text_splitter.split_documents(docs)
        print(f"Split into {len(split_docs)} chunks")

        # Initialize embeddings (small, CPU-friendly sentence-transformer).
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )
        print("Embeddings model loaded")

        # Create FAISS database.
        db = FAISS.from_documents(split_docs, embeddings)
        print("Vector database created successfully")
        return db
    except Exception as e:
        print(f"Error building vector database: {str(e)}")
        traceback.print_exc()
        return None


# ---------------------------
# CHATBOT PIPELINE
# ---------------------------
def create_chatbot(api_key, db):
    """Set up ConversationalRetrievalChain with memory.

    Wires the Together LLM, a top-4 similarity retriever over ``db``, and a
    buffer memory keyed for the chain. Returns the chain, or ``None`` on
    failure.
    """
    try:
        print("Creating chatbot...")
        client = Together(api_key=api_key)
        llm = TogetherLLM(client=client)

        retriever = db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 4}
        )

        # output_key="answer" is required because the chain also returns
        # source_documents and memory must know which key to store.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,
            verbose=True,
        )
        print("Chatbot created successfully")
        return qa_chain
    except Exception as e:
        print(f"Error creating chatbot: {str(e)}")
        traceback.print_exc()
        return None


# ---------------------------
# GRADIO APP
# ---------------------------
def create_app():
    """Build and return the Gradio Blocks UI."""
    with gr.Blocks(title="šŸ“Š Financial Policy Document Chatbot",
                   theme=gr.themes.Soft()) as app:
        gr.Markdown("# šŸ“Š Financial Policy Document Chatbot")
        gr.Markdown("""
        Upload a financial policy PDF document and ask questions about its content.
        The chatbot will provide answers with page references from the document.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="Together API Key",
                    placeholder="Enter your Together API key here...",
                    type="password",
                )
                pdf_file = gr.File(
                    label="Upload Financial Policy PDF",
                    file_types=[".pdf"],
                )
                process_button = gr.Button("šŸ“„ Process PDF", variant="primary")
                status_message = gr.Textbox(label="Status", interactive=False, lines=3)

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Chat with Financial Policy Document", height=500)
                with gr.Row():
                    question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="Example: What is the budget allocation for infrastructure?",
                        lines=2,
                        scale=4
                    )
                    submit_button = gr.Button("šŸ” Ask", variant="secondary", scale=1)

                gr.Markdown("""
                **Sample Questions:**
                - What is the debt policy outlined in the document?
                - How much budget is allocated for infrastructure?
                - What are the revenue sources mentioned?
                - What are the key financial objectives?
                """)

        # State variables shared between the processing and chat handlers.
        db_state = gr.State()
        qa_chain_state = gr.State()

        # Event handlers
        def process_pdf_handler(pdf_file, api_key):
            """Generator: stream status updates while building the pipeline."""
            try:
                # BUG FIX: this is a generator function, so a plain
                # `return value` would discard the message without ever
                # showing it — validation errors must be yielded.
                if pdf_file is None:
                    yield "āš ļø Please upload a PDF file.", None, None
                    return
                if not api_key or api_key.strip() == "":
                    yield "āš ļø Please enter your Together API key.", None, None
                    return

                yield "šŸ”„ Processing PDF... This may take a few moments.", None, None

                # Extract text from PDF.
                docs = extract_text_from_pdf(pdf_file)
                if not docs:
                    yield "āš ļø No text could be extracted from the PDF.", None, None
                    return

                # Check if extraction was successful (ignore placeholder pages).
                valid_docs = [doc for doc in docs
                              if not doc.page_content.startswith("[Error")
                              and not doc.page_content.startswith("[No extractable")]
                if len(valid_docs) == 0:
                    yield "āš ļø No readable text found in the PDF.", None, None
                    return

                yield (f"šŸ“„ Extracted text from {len(docs)} pages. "
                       f"Building search database..."), None, None

                # Build vector database.
                db = build_vector_db(docs)
                if db is None:
                    yield "āš ļø Failed to build search database.", None, None
                    return

                yield "šŸ” Search database created. Setting up chatbot...", None, None

                # Create chatbot.
                qa_chain = create_chatbot(api_key, db)
                if qa_chain is None:
                    yield "āš ļø Failed to create chatbot.", None, None
                    return

                yield (f"āœ… Successfully processed PDF with {len(docs)} pages. "
                       f"Ready to answer questions!"), db, qa_chain

            except Exception as e:
                error_msg = f"āŒ Error processing PDF: {str(e)}"
                print(f"Process PDF Error: {str(e)}")
                traceback.print_exc()
                yield error_msg, None, None

        def chat_handler(user_question, qa_chain, history):
            """Answer one question; returns (updated history, cleared input)."""
            if not user_question or user_question.strip() == "":
                return history, ""
            if qa_chain is None:
                return history + [(user_question,
                                   "āš ļø Please process a PDF document first.")], ""
            try:
                # Get response from the chain (.invoke replaces the
                # deprecated direct-call style).
                result = qa_chain.invoke({"question": user_question})
                answer = result["answer"]

                # Add source references from the retrieved chunks.
                if "source_documents" in result and result["source_documents"]:
                    pages = [doc.metadata["page"]
                             for doc in result["source_documents"]
                             if "page" in doc.metadata]
                    if pages:
                        unique_pages = sorted(set(pages))
                        if len(unique_pages) == 1:
                            answer += f"\n\nšŸ“Œ **Reference:** Page {unique_pages[0]}"
                        else:
                            answer += (f"\n\nšŸ“Œ **References:** Pages "
                                       f"{', '.join(map(str, unique_pages))}")

                return history + [(user_question, answer)], ""
            except Exception as e:
                error_response = f"āŒ Error processing question: {str(e)}"
                print(f"Chat Error: {str(e)}")
                traceback.print_exc()
                return history + [(user_question, error_response)], ""

        # Bind events.
        process_button.click(
            fn=process_pdf_handler,
            inputs=[pdf_file, api_key_input],
            outputs=[status_message, db_state, qa_chain_state],
        )

        # BUG FIX: the original wired outputs=[chatbot, chatbot, question],
        # writing the chatbot component twice; on errors the stale second
        # value overwrote the error bubble. One chatbot output suffices —
        # the current history is already fed in via `inputs`.
        submit_button.click(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )
        question.submit(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )

    return app


# ---------------------------
# MAIN EXECUTION
# ---------------------------
if __name__ == "__main__":
    app = create_app()
    app.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True
    )