Spaces:
Runtime error
Runtime error
"""
SCRIPT 2/5: augmenter.py - Context Enhancement and Prompt Engineering for Shoe RAG Pipeline

Colab - https://colab.research.google.com/drive/1rq-ywjykHBw7xPXCmd3DmZdK6T9bhDtA?usp=sharing

This script handles the AUGMENTATION phase of the RAG pipeline, including:
- Query classification and analysis
- Context formatting and enhancement
- Prompt engineering with different strategies
- Advanced prompt templates for different query types

Key Concepts:
- Prompt Engineering: Designing effective prompts to guide LLM responses
- Context Enhancement: Structuring retrieved data for optimal LLM understanding
- Query Classification: Determining intent to apply appropriate prompt strategies
- Template-based Prompting: Using structured templates for consistent results

Required Dependencies:
- typing: Type hints for better code structure
- enum: For defining query types
- argparse: Command-line interface for the test commands below
- retriever (local module): Provides MyntraShoesEnhanced and run_shoes_search for fetching real data

Commands to run:
# Test query classification
python augmenter.py --query "recommend running shoes for men" --classify-only --search-type auto
# Test context formatting with real data
python augmenter.py --query "show me casual sneakers" --test-formatting
# Generate prompt for recommendation query with real data
python augmenter.py --query "recommend comfortable shoes" --generate-prompt
# Generate prompt for search query with real data
python augmenter.py --query "blue sneakers" --generate-prompt
# Test with image search context
python augmenter.py --query "hf_shoe_images/shoe_0000.jpg" --search-type image --generate-prompt
# Test with auto search type detection
python augmenter.py --query "recommend shoes" --generate-prompt
# Use custom database settings
python augmenter.py --query "sneakers" --generate-prompt --database "myntra_shoes_db" --table-name "myntra_shoes_table"
"""
| import argparse | |
| from enum import Enum | |
| from typing import Any, Dict, List | |
| # Import retriever components | |
| from retriever import MyntraShoesEnhanced, run_shoes_search | |
class QueryType(Enum):
    """Intent categories that a shoe-related query can be classified into.

    Each member's value is a lowercase string label for the category.
    """

    # The user is asking for advice on which shoe to pick.
    RECOMMENDATION = "recommendation"
    # The user is looking up shoes that match some criteria.
    SEARCH = "search"
class SimpleShoePrompts:
    """AUGMENTATION: Simplified prompt system for shoe RAG with context enhancement.

    Classifies a query as a recommendation or a search request, renders the
    retrieved shoe records as a numbered list, and combines a system prompt
    with a query-type-specific user prompt.

    Attributes:
        system_prompts: Maps a ``QueryType`` value ("recommendation" or
            "search") to the system prompt text used for that query type.
    """

    # Substrings whose presence (case-insensitively) marks a recommendation query.
    _RECOMMENDATION_KEYWORDS = ("recommend", "suggest", "best", "need", "looking for")

    # Lower-cased textual stand-ins for "no value" that must not reach the prompt.
    _MISSING_MARKERS = frozenset({"", "none", "nan"})

    def __init__(self):
        self.system_prompts = {
            "recommendation": (
                "You are a helpful assistant. Choose from the given shoe options "
                "and give a short, simple recommendation. "
                "Do not make up any information."
            ),
            "search": (
                "You are a knowledgeable shoe assistant. Help customers understand "
                "the available shoe options\n"
                "that match their search criteria, providing detailed information "
                "about features and benefits."
            ),
        }

    def classify_query(self, query: str) -> "QueryType":
        """Classify query into recommendation or search type.

        Args:
            query: The user's text query.

        Returns:
            ``QueryType.RECOMMENDATION`` if the lower-cased query contains any
            recommendation keyword as a substring, otherwise
            ``QueryType.SEARCH``.
        """
        query_lower = query.lower()
        if any(keyword in query_lower for keyword in self._RECOMMENDATION_KEYWORDS):
            return QueryType.RECOMMENDATION
        return QueryType.SEARCH

    @classmethod
    def _clean_field(cls, value: Any) -> str:
        """Return ``value`` as stripped text, or "" when it represents no value."""
        if value is None:
            return ""
        text = str(value).strip()
        return "" if text.lower() in cls._MISSING_MARKERS else text

    def format_shoes_context(self, shoes: List[Dict[str, Any]]) -> str:
        """AUGMENTATION: Format retrieved shoes into readable context for LLM.

        Each shoe becomes one numbered line of the form
        ``"<n>. <product_type> for <gender> (<color>)"``; the gender and color
        parts are omitted when missing, and the product type falls back to
        "Shoe".

        Args:
            shoes: Retrieved shoe records. Only the "product_type", "gender"
                and "color" keys are read.

        Returns:
            The numbered lines joined by newlines; an empty string for an
            empty list.
        """
        lines = []
        for index, shoe in enumerate(shoes, 1):
            # Every field goes through the same missing-value filter so that a
            # None or "None" entry never appears literally in the prompt.
            product_type = self._clean_field(shoe.get("product_type")) or "Shoe"
            gender = self._clean_field(shoe.get("gender"))
            color = self._clean_field(shoe.get("color"))

            shoe_name = f"{product_type} for {gender}" if gender else product_type
            if color:
                shoe_name += f" ({color})"
            lines.append(f"{index}. {shoe_name}")
        return "\n".join(lines)

    def generate_prompt(
        self, query: str, shoes: List[Dict[str, Any]], search_type: str = "text"
    ) -> str:
        """AUGMENTATION: Generate complete prompt based on query type and retrieved context.

        Args:
            query: The user's query.
            shoes: Retrieved shoe records to present as context.
            search_type: "image" forces the search strategy; any other value
                lets ``classify_query`` choose the strategy.

        Returns:
            The system prompt and the user prompt separated by a blank line.
        """
        # If it's an image search, always treat as search query type
        if search_type == "image":
            query_type = QueryType.SEARCH
        else:
            query_type = self.classify_query(query)

        system_prompt = self.system_prompts[query_type.value]
        context = self.format_shoes_context(shoes)

        if query_type == QueryType.RECOMMENDATION:
            # Add a summary to guide recommendations
            intent_summary = (
                f"Based on the query, the user is likely looking for {query.lower()}."
            )
            user_prompt = f"""{intent_summary}
Available Options:
{context}
Your task:
- Recommend the best option(s) that align most closely with the query.
- Reference specific attributes (e.g., gender, product type, color, or other features) in your reasoning.
- Avoid adding details not provided in the context.
Provide your recommendation in 2-3 sentences."""
        else:  # SEARCH
            user_prompt = f"""Here are shoes matching: "{query}"
Search Results:
{context}
Explain how well these shoes meet the search criteria and highlight their relevant features."""

        return f"{system_prompt}\n\n{user_prompt}"
def detect_search_type(search_query) -> str:
    """Infer whether a query should be treated as a text or an image search.

    Meant to match the detection logic in retriever.py.

    Args:
        search_query: Either a string (a text query or an image file path) or
            an object exposing a ``save`` attribute, such as a PIL Image.

    Returns:
        "image" for a string ending in a known image extension (the match is
        case-sensitive) or for an object with a ``save`` attribute; "text"
        for everything else.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".gif")

    if isinstance(search_query, str):
        # A string is an image search only when it looks like an image path.
        return "image" if search_query.endswith(image_suffixes) else "text"

    # Non-string input: anything that can be saved is taken to be a PIL Image.
    return "image" if hasattr(search_query, "save") else "text"
def get_real_shoes_data(
    query: str,
    search_type: str = "text",
    database: str = "myntra_shoes_db",
    table_name: str = "myntra_shoes_table",
    limit: int = 3,
) -> List[Dict[str, Any]]:
    """Get real shoes data from retriever for testing purposes.

    Args:
        query: Search query forwarded to the retriever.
        search_type: Search type forwarded to the retriever.
        database: Database forwarded to the retriever.
        table_name: Table name forwarded to the retriever.
        limit: Result limit forwarded to the retriever.

    Returns:
        The first element of the pair returned by ``run_shoes_search``; the
        second element is discarded.

    Raises:
        RuntimeError: If the retriever call fails for any reason. The original
            error is attached as ``__cause__`` so the root cause stays visible
            in the traceback.
    """
    try:
        results, _ = run_shoes_search(
            database=database,
            table_name=table_name,
            schema=MyntraShoesEnhanced,
            search_query=query,
            limit=limit,
            search_type=search_type,
            output_folder="output_augmenter",
        )
    except Exception as e:
        raise RuntimeError(
            f"Could not retrieve real data: {e}. Please ensure the database is set up correctly."
        ) from e
    return results
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the augmentation test harness."""
    parser = argparse.ArgumentParser(
        description="Augmentation component for Shoe RAG Pipeline"
    )
    parser.add_argument("--query", type=str, required=True, help="Query to process")
    parser.add_argument(
        "--search-type",
        choices=["auto", "text", "image"],
        default="auto",
        help="Search type for prompt generation (auto-detect or force specific type)",
    )
    parser.add_argument(
        "--classify-only", action="store_true", help="Only classify the query type"
    )
    parser.add_argument(
        "--test-formatting",
        action="store_true",
        help="Test context formatting with real data",
    )
    parser.add_argument(
        "--generate-prompt", action="store_true", help="Generate complete prompt"
    )
    parser.add_argument(
        "--database",
        type=str,
        default="myntra_shoes_db",
        help="Database path for real data",
    )
    parser.add_argument(
        "--table-name",
        type=str,
        default="myntra_shoes_table",
        help="Table name for real data",
    )
    parser.add_argument(
        "--limit", type=int, default=3, help="Number of shoes to retrieve for testing"
    )
    return parser


def main() -> None:
    """Run the augmentation demo selected by the command-line flags.

    Flags are checked in this order: ``--classify-only`` prints only the
    classification summary; ``--test-formatting`` prints the retrieved records
    and their formatted context; ``--generate-prompt`` prints the system
    prompt, the formatted context and the complete prompt; with none of these
    flags, the formatted context and the complete prompt are printed.
    """
    # NOTE(review): the decorative glyphs in the messages below were mojibake in
    # the reviewed copy of this file and have been restored as emoji; some are
    # best guesses - confirm against the original source if exact output matters.
    args = _build_arg_parser().parse_args()

    # Initialize prompt manager
    prompt_manager = SimpleShoePrompts()

    # Auto-detect search type if needed
    if args.search_type == "auto":
        detected_search_type = detect_search_type(args.query)
        if detected_search_type == "image":
            print(f"🖼️ Detected image search: {args.query}")
        else:
            print(f"📝 Detected text search: {args.query}")
    else:
        detected_search_type = args.search_type

    # generate_prompt() always uses the search strategy for image queries, so
    # report the same type here; otherwise the "Query Type" and "System Prompt"
    # shown below could disagree with the prompt that is actually built.
    if detected_search_type == "image":
        query_type = QueryType.SEARCH
    else:
        query_type = prompt_manager.classify_query(args.query)

    print("=" * 60)
    print("📋 AUGMENTATION RESULTS")
    print("=" * 60)
    print(f"Query: {args.query}")
    print(f"Query Type: {query_type.value}")
    print(f"Search Type: {args.search_type} → {detected_search_type}")

    if args.classify_only:
        print("\n🎯 Query Classification Complete!")
        return

    # Every remaining mode works on the same retrieved records, so fetch once.
    shoes_data = get_real_shoes_data(
        query=args.query,
        search_type=detected_search_type,
        database=args.database,
        table_name=args.table_name,
        limit=args.limit,
    )
    formatted_context = prompt_manager.format_shoes_context(shoes_data)

    if args.test_formatting:
        # Test context formatting with real data
        print("\n📄 Context Formatting Test (Real Data):")
        print("-" * 40)
        print("Real Shoes Data:")
        for i, shoe in enumerate(shoes_data, 1):
            print(f" {i} {shoe}")
        print("\nFormatted Context:")
        print("-" * 40)
        print(formatted_context)
    elif args.generate_prompt:
        # Generate complete prompt with real data
        complete_prompt = prompt_manager.generate_prompt(
            args.query, shoes_data, detected_search_type
        )
        print("\n📝 Complete Prompt Generation (Real Data):")
        print("-" * 40)
        print("System Prompt:")
        print(prompt_manager.system_prompts[query_type.value])
        print("\nFormatted Context:")
        print(formatted_context)
        print("\nComplete Prompt:")
        print("-" * 40)
        print(complete_prompt)
    else:
        # Show all information with real data
        complete_prompt = prompt_manager.generate_prompt(
            args.query, shoes_data, detected_search_type
        )
        print("\n📄 Context Formatting (Real Data):")
        print("-" * 40)
        print(formatted_context)
        print("\n📝 Complete Prompt:")
        print("-" * 40)
        print(complete_prompt)

    # NOTE(review): printed only on paths that actually retrieved data, so the
    # "(Used Real Data)" claim is never shown for --classify-only. The original
    # nesting of this line is not recoverable from the reviewed copy - confirm.
    print("\n✅ Augmentation Complete! (Used Real Data)")


if __name__ == "__main__":
    main()