"""TinyLlama chat demo with a small retrieval-augmented knowledge base.

The knowledge base is plain text that the user types in, scrapes from a
Fandom wiki page, or derives from an image caption (BLIP).  The text is
split into chunks, embedded with a sentence-transformer, and the chunks
most similar to each chat message are prepended to the prompt.
"""
from threading import Thread
from urllib.parse import urlparse

import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    BlipProcessor,
    BlipForConditionalGeneration,
)

# --- CONFIGURATION ---

# Seconds to wait for the Fandom server before giving up on a scrape.
REQUEST_TIMEOUT_SECONDS = 15
# Paragraphs at or below this length are dropped when scraping.
MIN_PARAGRAPH_CHARS = 50
# Chunks at or below this length are dropped when indexing.
MIN_CHUNK_CHARS = 20
# Minimum cosine similarity for a chunk to be used as context.
SIMILARITY_THRESHOLD = 0.25

# 1. LLM: TinyLlama
print("Loading TinyLlama...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# 2. Embedding Model: For Text RAG
print("Loading Embedding Model...")
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# 3. Vision Model: BLIP (for Image to Text)
# We use this to convert images into text descriptions so TinyLlama can "read" them.
print("Loading Vision Model (BLIP)...")
vision_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
vision_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Device Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
vision_model = vision_model.to(device)

# --- GLOBAL STATE FOR RAG ---
KNOWLEDGE_CHUNKS = []
KNOWLEDGE_EMBEDDINGS = None
RAG_ENABLED = False

# System content
DEFAULT_SYSTEM_PROMPT = """You are TinyLlama, a friendly and helpful AI assistant.
You are based on the TinyLlama-1.1B-Chat model."""
SYSTEM_CONTENT = DEFAULT_SYSTEM_PROMPT


class StopOnTokens(StoppingCriteria):
    """Stop generation once the last generated token is one of ``stop_ids``.

    Args:
        stop_ids: Token ids that end generation.  Defaults to ``(2,)``, the
            id this script originally hard-coded.
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        self.stop_ids = tuple(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence in the batch is inspected.
        return int(input_ids[0][-1]) in self.stop_ids


# --- NEW TOOL FUNCTIONS ---

def _is_fandom_url(url):
    """Return True when ``url`` is http(s) and its host is fandom.com or a subdomain."""
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    host = (parsed.hostname or "").lower()
    if parsed.scheme not in ("http", "https"):
        return False
    return host == "fandom.com" or host.endswith(".fandom.com")


def scrape_wikifandom(url):
    """Scrape paragraph text from a Fandom wiki page.

    Args:
        url: Address of the page; its host must be ``fandom.com`` or a
            subdomain of it.

    Returns:
        The page's paragraphs joined by blank lines, or a string starting
        with ``"Error"`` describing what went wrong.  Callers detect failure
        by that prefix.
    """
    url = (url or "").strip()
    # Check the parsed hostname rather than a substring, so that e.g.
    # "http://evil.example/?fandom.com" is rejected.
    if not _is_fandom_url(url):
        return "Error: Please provide a valid URL containing 'fandom.com'"

    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # Without a timeout a stalled server would block this UI callback.
        response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS)
        if response.status_code != 200:
            return f"Error: Failed to fetch page (Status {response.status_code})"

        soup = BeautifulSoup(response.content, 'html.parser')

        # Fandom usually puts the main article text in 'mw-parser-output'
        content_div = soup.find('div', class_='mw-parser-output')
        if not content_div:
            # Fallback for some wiki layouts
            content_div = soup.find('div', id='content')
        if not content_div:
            return "Error: Could not parse content from this Fandom page."

        # Extract paragraphs, skipping short fragments.
        paragraph_texts = (p.get_text() for p in content_div.find_all('p'))
        return "\n\n".join(
            text for text in paragraph_texts if len(text) > MIN_PARAGRAPH_CHARS
        )
    except Exception as e:
        return f"Error scraping URL: {str(e)}"


def process_image_to_text(image):
    """Generate a caption for an image using BLIP.

    Args:
        image: The image to caption, or ``None``.

    Returns:
        ``""`` when ``image`` is ``None``; otherwise a sentence embedding the
        caption, or a string starting with ``"Error"`` on failure.
    """
    if image is None:
        return ""

    try:
        # Prepare image
        inputs = vision_processor(image, return_tensors="pt").to(device)
        # Generate caption
        out = vision_model.generate(**inputs, max_new_tokens=50)
        caption = vision_processor.decode(out[0], skip_special_tokens=True)
        return f"Image Context: The user uploaded an image that shows {caption}."
    except Exception as e:
        return f"Error processing image: {str(e)}"


# --- RAG FUNCTIONS ---

def process_knowledge_base(text_content):
    """Split text into chunks, embed them, and store them as the RAG index.

    Args:
        text_content: Knowledge-base text; chunks are separated by blank
            lines.  Empty or whitespace-only text clears the index.

    Returns:
        ``(status_message, rag_active)`` where ``rag_active`` reflects
        whether an index is active after the call.
    """
    global KNOWLEDGE_CHUNKS, KNOWLEDGE_EMBEDDINGS, RAG_ENABLED

    if not (text_content or "").strip():
        # Drop the old index too, so stale chunks cannot be served later.
        KNOWLEDGE_CHUNKS = []
        KNOWLEDGE_EMBEDDINGS = None
        RAG_ENABLED = False
        return "Knowledge base cleared.", False

    # Chunking
    stripped = (chunk.strip() for chunk in text_content.split('\n\n'))
    chunks = [chunk for chunk in stripped if len(chunk) > MIN_CHUNK_CHARS]

    if not chunks:
        # Any previous index is left untouched; report its real state.
        return "No valid text found to process.", RAG_ENABLED

    # Create Embeddings
    try:
        embeddings = embedder.encode(chunks, convert_to_tensor=True)
    except Exception as e:
        return f"Error creating embeddings: {str(e)}", RAG_ENABLED

    KNOWLEDGE_CHUNKS = chunks
    KNOWLEDGE_EMBEDDINGS = embeddings
    RAG_ENABLED = True
    return f"Indexed {len(chunks)} chunks. RAG Ready.", True


def retrieve_context(query, top_k=3):
    """Return the indexed chunks most similar to ``query``.

    Args:
        query: Text to search for.
        top_k: Maximum number of chunks to return.

    Returns:
        The matching chunks joined by blank lines, or ``""`` when RAG is
        disabled, nothing is indexed, or no chunk exceeds the threshold.
    """
    if not RAG_ENABLED or KNOWLEDGE_EMBEDDINGS is None or not KNOWLEDGE_CHUNKS:
        return ""

    query_embedding = embedder.encode(query, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, KNOWLEDGE_EMBEDDINGS)[0]
    scores, indices = torch.topk(cos_scores, k=min(top_k, len(KNOWLEDGE_CHUNKS)))

    # Slightly lower threshold for broader context
    retrieved_text = [
        KNOWLEDGE_CHUNKS[idx]
        for score, idx in zip(scores.tolist(), indices.tolist())
        if score > SIMILARITY_THRESHOLD
    ]
    return "\n\n".join(retrieved_text)


# --- PREDICTION FUNCTION ---

def _history_to_pairs(history):
    """Normalise chat history into ``[user_text, assistant_text]`` pairs.

    Accepts either a sequence of two-item pairs or a sequence of dicts with
    ``"role"`` and ``"content"`` keys.  Non-string dict content is replaced
    by an empty string.
    """
    pairs = []
    for item in history or []:
        if isinstance(item, dict):
            content = item.get("content")
            text = content if isinstance(content, str) else ""
            role = item.get("role")
            if role == "user":
                pairs.append([text, ""])
            elif role == "assistant":
                if pairs and not pairs[-1][1]:
                    pairs[-1][1] = text
                else:
                    pairs.append(["", text])
        else:
            pairs.append([item[0] or "", item[1] or ""])
    return pairs


def predict(message, history, system_content=None):
    """Stream TinyLlama's reply to ``message``.

    Args:
        message: The user's new message.
        history: Earlier turns, as pairs or role/content dicts.
        system_content: Optional system prompt; falls back to
            ``SYSTEM_CONTENT`` when empty.

    Yields:
        The reply accumulated so far, once per streamed piece of text.
    """
    current_system_content = system_content if system_content else SYSTEM_CONTENT

    if RAG_ENABLED:
        retrieved = retrieve_context(message)
        if retrieved:
            context = f"\nUse this context to answer:\n{retrieved}\n"
            message = f"{context}\nQuestion: {message}"

    turns = _history_to_pairs(history) + [[message, ""]]

    system_prompt = f"<|system|>\n{current_system_content}"
    conversation = "".join(
        "\n<|user|>:" + user_text + "\n<|assistant|>:" + bot_text
        for user_text, bot_text in turns
    )
    messages = system_prompt + conversation

    eos_id = tokenizer.eos_token_id
    stop = StopOnTokens([eos_id] if eos_id is not None else (2,))

    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Generation runs in a background thread so tokens can be streamed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # The previous check tested for the empty string, which is contained in
    # every string, so the loop always broke before yielding anything.
    eos_text = tokenizer.eos_token or "</s>"
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if eos_text in partial_message:
            break
        yield partial_message


# --- UI LOGIC ---

def add_fandom_content(url, current_text):
    """Fetch Fandom content and append it to the knowledge-base text.

    Returns:
        ``(new_text, status_message)``; the text is unchanged on failure.
    """
    scraped_text = scrape_wikifandom(url)
    if scraped_text.startswith("Error"):
        return current_text, scraped_text  # Return error in status
    if not scraped_text.strip():
        # The page was fetched but had no paragraph long enough to keep.
        return current_text, "Error: No article text found on this Fandom page."

    new_text = ((current_text or "") + "\n\n" + scraped_text).strip()
    return new_text, "Fandom content added to Knowledge Base text area."


def add_image_content(image, current_text):
    """Caption an image and append the description to the knowledge-base text.

    Returns:
        ``(new_text, status_message)``; the text is unchanged on failure or
        when no image was supplied.
    """
    description = process_image_to_text(image)
    if not description:
        # process_image_to_text returns "" when there is no image.
        return current_text, "No image provided; nothing was added."
    if description.startswith("Error"):
        return current_text, description

    new_text = ((current_text or "") + "\n\n" + description).strip()
    return new_text, "Image analysis added. RAG now knows what this image looks like."


# --- GRADIO INTERFACE ---

with gr.Blocks(title="TinyLlama Multi-Source RAG") as demo:
    gr.Markdown("# 🦙 TinyLlama RAG (WikiFandom + Images)")
    gr.Markdown("Chat with TinyLlama. Build a knowledge base from text, WikiFandom URLs, or Images.")

    with gr.Row():
        # Left Column: Chat
        with gr.Column(scale=2):
            chat_interface = gr.ChatInterface(
                predict,
                examples=['Who is in the image?', 'Tell me about the wiki page'],
            )

        # Right Column: Tools
        with gr.Column(scale=1):
            # --- RAG INPUTS ---
            with gr.Accordion("📚 Knowledge Sources", open=True):
                # Main Text Area (Where all data ends up)
                kb_input = gr.Textbox(
                    label="Compiled Knowledge Base",
                    lines=6,
                    placeholder="Data from Wiki or Images will appear here...",
                    interactive=True,
                )

                with gr.Tab("🔗 WikiFandom"):
                    url_input = gr.Textbox(
                        label="Fandom URL",
                        placeholder="https://starwars.fandom.com/wiki/Luke_Skywalker",
                    )
                    scrape_btn = gr.Button("Scrape & Add Text")

                with gr.Tab("🖼️ Image Support"):
                    img_input = gr.Image(type="pil", label="Upload Image")
                    img_btn = gr.Button("Analyze & Add Description")

                # Build Button
                with gr.Row():
                    process_btn = gr.Button("⚡ Build Knowledge Base", variant="primary")
                    rag_status = gr.Checkbox(label="RAG Active", interactive=False, value=False)
                status_output = gr.Textbox(label="Status", interactive=False)

            # System Prompt
            # NOTE(review): this textbox is not passed to predict() by the
            # ChatInterface above or by any handler below, so editing it has
            # no effect on the chat.
            with gr.Accordion("⚙️ System Settings", open=False):
                system_content_input = gr.Textbox(
                    value=SYSTEM_CONTENT, lines=2, label="System Prompt"
                )

    # --- EVENT HANDLERS ---

    # 1. Scrape Fandom -> Append to Textbox
    scrape_btn.click(
        add_fandom_content,
        inputs=[url_input, kb_input],
        outputs=[kb_input, status_output],
    )

    # 2. Analyze Image -> Append to Textbox
    img_btn.click(
        add_image_content,
        inputs=[img_input, kb_input],
        outputs=[kb_input, status_output],
    )

    # 3. Build RAG Index
    process_btn.click(
        process_knowledge_base,
        inputs=[kb_input],
        outputs=[status_output, rag_status],
    )

if __name__ == "__main__":
    demo.launch()