import os
import re
import json
import base64
import mimetypes
import argparse
import pathlib
import io
import requests
from typing import List, Dict, Any, Optional, TypedDict, Literal, Tuple
from dataclasses import dataclass, field
from io import BytesIO
from PIL import Image
from uuid import uuid4

# --- LangGraph / LangChain ---
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# --- OpenAI / Azure ---
from openai import AzureOpenAI, OpenAI
from dotenv import load_dotenv

# --- HF Transformers & Diffusers (Local VLM) ---
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, pipeline, BitsAndBytesConfig
import transformers.utils as _hf_utils

if not hasattr(_hf_utils, "FLAX_WEIGHTS_NAME"):
    _hf_utils.FLAX_WEIGHTS_NAME = "flax_model.msgpack"

from diffusers import DiffusionPipeline

try:
    import transformers.models.auto.video_processing_auto as _video_auto
    if getattr(_video_auto, "VIDEO_PROCESSOR_MAPPING_NAMES", None) is None:
        _video_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
except Exception as _video_err:
    print(f"Warning: unable to patch video processor registry: {_video_err}")

# --- Playwright (for screenshots) ---
from playwright.sync_api import sync_playwright

# --- Load Environment Variables ---
load_dotenv()

# ----------------------------------------------------------------------
# ----------------------------------------------------------------------
# ## SECTION 1: UNIFIED MODEL MANAGER (Singleton)
#
# Manages loading all models (Azure, Local VLM, Generator)
# to ensure they are only loaded into memory once.
# ----------------------------------------------------------------------
# ----------------------------------------------------------------------

# --- Configs from all files ---
QWEN_VL_MODEL_NAME = os.getenv("QWEN_VL_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
SD_GENERATOR_MODEL = os.getenv("SD_GENERATOR_MODEL", "segmind/tiny-sd")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32


class ModelManager:
    """Manages loading all models and API clients at startup."""
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, 'vlm_model'):  # Initialize only once
            print("Initializing and loading all models and clients...")

            # 1. Configure Azure Client
            self.AZURE_ENDPOINT = os.getenv("ENDPOINT_URL", "")
            self.AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
            if not self.AZURE_API_KEY or not self.AZURE_ENDPOINT:
                print("Warning: AZURE_OPENAI_API_KEY or ENDPOINT_URL not set.")
            else:
                self.azure_client = AzureOpenAI(
                    azure_endpoint=self.AZURE_ENDPOINT,
                    api_key=self.AZURE_API_KEY,
                    api_version="2024-10-21"
                )
                print("AzureOpenAI client loaded.")

            # 2. Configure OpenAI Client (for edit_node_tool)
            try:
                self.OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
                self.openai_client = OpenAI(api_key=self.OPENAI_API_KEY)
                print("OpenAI client loaded.")
            except KeyError:
                print("Warning: OPENAI_API_KEY not set. GPT editor tool will not work.")
                self.openai_client = None

            # 3. Configure and load the Local VLM (Qwen)
            print(f"Loading local VLM: {QWEN_VL_MODEL_NAME}...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4"
            )
            self.vlm_processor = AutoProcessor.from_pretrained(QWEN_VL_MODEL_NAME, trust_remote_code=True)
            self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                QWEN_VL_MODEL_NAME,
                torch_dtype=DTYPE,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=True
            )
            print("Local VLM (Qwen) loaded.")

            # 4. Configure and load the Generator
            print(f"Loading image generator: {SD_GENERATOR_MODEL}...")
            self.generator_pipe = DiffusionPipeline.from_pretrained(
                SD_GENERATOR_MODEL,
                torch_dtype=DTYPE
            )
            self.generator_pipe.enable_model_cpu_offload()
            print("Generator loaded.")

            print("All models and clients loaded and ready.")

    def get_azure_client(self) -> AzureOpenAI:
        if not hasattr(self, 'azure_client'):
            raise RuntimeError("Azure client not initialized. Set AZURE_OPENAI_API_KEY and ENDPOINT_URL.")
        return self.azure_client

    def get_openai_client(self) -> OpenAI:
        if not hasattr(self, 'openai_client') or self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized. Set OPENAI_API_KEY.")
        return self.openai_client

    # --- VLM Chat (in asset tool) ---
    def chat_vlm(self, messages, temperature=0.2, max_new_tokens=2048):
        gen_kwargs = {"do_sample": temperature > 0, "max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs["temperature"] = temperature
        inputs = self.vlm_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(self.vlm_model.device)
        with torch.no_grad():
            output_ids = self.vlm_model.generate(**inputs, **gen_kwargs)
        gen_only = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], output_ids)]
        return self.vlm_processor.batch_decode(gen_only, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

    def chat_llm(self, prompt: str):
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        return self.chat_vlm(messages, temperature=0.1, max_new_tokens=1024)

    # --- Generator (from asset_tool) ---
    def generate_image(self, prompt: str) -> Image.Image:
        print(f"Generating image with prompt: '{prompt}'")
        return self.generator_pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]

    # --- Azure Chat (from agent_azure_vlm) ---
    def chat_complete_azure(self, deployment: str, messages: List[Dict[str, Any]],
                            temperature: float, max_tokens: int) -> str:
        client = self.get_azure_client()
        resp = client.chat.completions.create(
            model=deployment,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return (resp.choices[0].message.content or "").strip()


# --- Initialize models ONCE ---
models = ModelManager()
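
# Illustrative usage sketch (commented out; assumes the env vars above are set and enough
# GPU memory is available -- these exact calls are not made elsewhere in this file):
#
#   mm = ModelManager()        # __new__ returns the same singleton instance every time
#   assert mm is models        # so `mm` and the module-level `models` are one object
#   headline = mm.chat_llm("Suggest a short headline for a landing page about solar panels.")
#   icon = mm.generate_image("flat vector illustration of a rocket, white background")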
# ----------------------------------------------------------------------
# ----------------------------------------------------------------------

### --- Utilities from asset_tool.py ---
def load_image(path: str) -> Image.Image:
    return Image.open(path).convert("RGB")


def b64img(pil_img: Image.Image) -> str:
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


### --- State from asset_tool.py ---
class AssetGraphState(TypedDict):
    """State for the asset-finding subgraph."""
    instructions: str
    bounding_box: Tuple[int, int]
    search_query: str
    found_image_url: Optional[str]
    final_asset_path: Optional[str]


### --- Nodes from asset_tool.py ---
def asset_prepare_search_query_node(state: AssetGraphState) -> dict:
    print("---(Asset Tool) NODE: Prepare Search Query---")
    prompt = f"""You are an expert at refining search queries. Extract only the essential visual keywords.
**CRITICAL INSTRUCTIONS:**
- DO NOT include words related to licensing.
- DO NOT include quotation marks.
User's request: "{state['instructions']}"
Respond with ONLY the refined search query."""
    raw_query = models.chat_llm(prompt)
    search_query = raw_query.strip().replace('"', '')
    print(f"Prepared search query: '{search_query}'")
    return {"search_query": search_query}


def asset_generate_image_node(state: AssetGraphState) -> dict:
    print("---(Asset Tool) NODE: Generate Image---")
    prompt = state["instructions"]
    generated_image = models.generate_image(prompt)
    output_dir = pathlib.Path("Outputs/Assets")
    output_dir.mkdir(parents=True, exist_ok=True)
    filename = f"generated_{uuid4()}.png"
    full_save_path = output_dir / filename
    generated_image.save(full_save_path)
    print(f"Image generated and saved to {full_save_path}")
    html_path = pathlib.Path("Assets") / filename
    final_asset_path = str(html_path.as_posix())
    return {"final_asset_path": final_asset_path}


def asset_download_and_resize_node(state: AssetGraphState) -> dict:
    print("---(Asset Tool) NODE: Download and Resize---")
    image_url = state.get("found_image_url")
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        img.thumbnail(state['bounding_box'])
        output_dir = pathlib.Path("Outputs/Assets")
        output_dir.mkdir(parents=True, exist_ok=True)
        filename = f"asset_{uuid4()}.png"
        full_save_path = output_dir / filename
        img.save(full_save_path)
        print(f"Image saved and resized to {full_save_path}")
        html_path = pathlib.Path("Assets") / filename
        final_asset_path = str(html_path.as_posix())
        return {"final_asset_path": final_asset_path}
    except Exception as e:
        print(f"Error processing image: {e}")
        return {"final_asset_path": None}


def asset_route_after_search(state: AssetGraphState) -> str:
    if state.get("found_image_url"):
        return "download_and_resize"
    else:
        print("Search failed. Routing to generate a new image.")
        return "generate_image"


def asset_pexels_search_node(state: AssetGraphState) -> dict:
    print("---(Asset Tool) TOOL: Searching Pexels---")
    api_key = os.getenv("PEXELS_API_KEY")
    search_query = state.get("search_query")
    if not api_key:
        print("Warning: PEXELS_API_KEY not set. Skipping search.")
        return {"found_image_url": None}
Skipping search.") return {"found_image_url": None} if not search_query: return {"found_image_url": None} headers = {"Authorization": api_key} params = {"query": search_query, "per_page": 1} try: response = requests.get("https://api.pexels.com/v1/search", headers=headers, params=params, timeout=10) response.raise_for_status() if response.json().get('photos'): image_url = response.json()['photos'][0]['src']['original'] print(f"Found a candidate image: {image_url}") return {"found_image_url": image_url} except requests.exceptions.RequestException as e: print(f"Pexels API Error: {e}") return {"found_image_url": None} ### --- Graph Builder from asset_tool.py --- def build_asset_graph(): workflow = StateGraph(AssetGraphState) workflow.add_node("prepare_search_query", asset_prepare_search_query_node) workflow.add_node("pexels_search", asset_pexels_search_node) workflow.add_node("generate_image", asset_generate_image_node) workflow.add_node("download_and_resize", asset_download_and_resize_node) workflow.set_entry_point("prepare_search_query") workflow.add_edge("prepare_search_query", "pexels_search") workflow.add_conditional_edges("pexels_search", asset_route_after_search) workflow.add_edge("generate_image", END) workflow.add_edge("download_and_resize", END) return workflow.compile() # --- Compile the graph --- asset_agent_app = build_asset_graph() # ---------------------------------------------------------------------- # ---------------------------------------------------------------------- # ## SECTION 3: CODE EDITOR TOOL # # This is the self-contained graph for editing HTML. # It will be used as a tool by the "Brain". # ---------------------------------------------------------------------- # ---------------------------------------------------------------------- class CodeEditorState(TypedDict): html_code: str user_request: str model_choice: Literal["gpt-4o-mini-2", "qwen-local"] messages: list[str] EDITOR_SYSTEM_PROMPT = """ You are an expert senior web developer specializing in HTML, CSS, and JavaScript. Your task is to take an existing HTML file, a user's request for changes, and to output the *new, complete, and updated HTML file*. **CRITICAL RULES:** 1. **Output ONLY the Code:** Your entire response MUST be *only* the raw, updated HTML code. 2. **No Conversation:** Do NOT include "Here is the updated code:", "I have made the following changes:", or any other explanatory text, comments, or markdown formatting. 3. **Return the Full File:** Always return the complete HTML file, from `` to ``, incorporating the requested changes. Do not return just a snippet. 
""" def _clean_llm_output(code: str) -> str: """Removes common markdown formatting.""" code = code.strip() if code.startswith("```html"): code = code[7:] if code.endswith("```"): code = code[:-3] return code.strip() def _call_gpt_editor(html_code: str, user_request: str, model: str) -> str: """Uses OpenAI (GPT) model.""" user_prompt = f"**User Request:**\n{user_request}\n\n**Original HTML Code:**\n```html\n{html_code}\n```\n\n**Your updated HTML Code:**" try: client = models.get_openai_client() response = client.chat.completions.create( model=model, messages=[ {"role": "system", "content": EDITOR_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt} ], temperature=0.0, max_tokens=8192, ) edited_code = response.choices[0].message.content return _clean_llm_output(edited_code) except Exception as e: print(f"Error calling OpenAI API: {e}") return f"\n{html_code}" def _call_qwen_editor(html_code: str, user_request: str) -> str: """Uses Local Qwen VLM.""" user_prompt = f"**User Request:**\n{user_request}\n\n**Original HTML Code:**\n```html\n{html_code}\n```\n\n**Your updated HTML Code:**" messages = [ {"role": "system", "content": [{"type": "text", "text": EDITOR_SYSTEM_PROMPT}]}, {"role": "user", "content": [{"type": "text", "text": user_prompt}]} ] try: edited_code = models.chat_vlm(messages, temperature=0.0, max_new_tokens=8192) return _clean_llm_output(edited_code) except Exception as e: print(f"Error calling local Qwen VLM: {e}") return f"\n{html_code}" def node_edit_code(state: CodeEditorState) -> dict: print("---(Edit Tool) NODE: Edit Code---") html_code, user_request, model_choice = state['html_code'], state['user_request'], state['model_choice'] messages = state.get('messages', []) if not user_request: return {"messages": messages + ["No user request provided. Skipping edit."]} try: if "gpt" in model_choice.lower(): new_html_code = _call_gpt_editor(html_code, user_request, model_choice) else: new_html_code = _call_qwen_editor(html_code, user_request) msg = f"Code edit complete using {model_choice}." print(msg) return {"html_code": new_html_code, "user_request": "", "messages": messages + [msg]} except Exception as e: error_msg = f"Error in code editing node: {e}" print(error_msg) return {"html_code": html_code, "messages": messages + [error_msg]} def build_edit_graph(): workflow = StateGraph(CodeEditorState) workflow.add_node("edit_code", node_edit_code) workflow.set_entry_point("edit_code") workflow.add_edge("edit_code", END) return workflow.compile(checkpointer=MemorySaver()) # --- Compile the graph --- edit_agent_app = build_edit_graph() # ---------------------------------------------------------------------- # ---------------------------------------------------------------------- # ## SECTION 4: AZURE VLM PIPELINE (RE-ORDERED) # # This pipeline is reordered to be much faster. # 1. CodeGen runs FIRST, creating placeholders. # 2. A fast regex parser finds the placeholders. # 3. Asset search runs. # 4. A patcher node inserts the asset paths. # 5. Scoring & Refinement run as normal. # # This completely removes the slow local VLM call from this graph. 
# ----------------------------------------------------------------------
# ----------------------------------------------------------------------
# ## SECTION 4: AZURE VLM PIPELINE (RE-ORDERED)
#
# This pipeline is reordered to be much faster.
# 1. CodeGen runs FIRST, creating placeholders.
# 2. A fast regex parser finds the placeholders.
# 3. Asset search runs.
# 4. A patcher node inserts the asset paths.
# 5. Scoring & Refinement run as normal.
#
# This completely removes the slow local VLM call from this graph.
# ----------------------------------------------------------------------
# ----------------------------------------------------------------------

## --- Helpers ---
_SCORE_KEYS = ["aesthetics", "completeness", "layout_fidelity", "text_legibility", "visual_balance"]


def _section(text: str, name: str) -> str:
    pat = rf"{name}:\s*\n(.*?)(?=\n[A-Z_]+:\s*\n|\Z)"
    m = re.search(pat, text, flags=re.S)
    return m.group(1).strip() if m else ""


def _score_val(block: str, key: str, default: int = 0) -> int:
    m = re.search(rf"{key}\s*:\s*(-?\d+)", block, flags=re.I)
    try:
        return int(m.group(1)) if m else default
    except Exception:
        return default


def encode_image_to_data_url(path: str) -> str:
    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"


def extract_html(text: str) -> str:
    m = re.search(r"```html(.*?)```", text, flags=re.S | re.I)
    if m:
        return m.group(1).strip()
    i = text.lower().find("<html")
    j = text.lower().rfind("</html>")
    if i != -1 and j != -1:
        return text[i:j + len("</html>")].strip()
    return text.strip()


CODEGEN_PROMPT = r"""
- Keep all CSS in a single inline <style> block; no external CSS/JS.
- Define CSS variables from the palette and use them consistently.
- Implement the layout: container max-width, gaps, grid columns, and stacking rules per breakpoints.
- **CRITICAL ASSET RULE**: If you need an image (logo, hero, card image, etc.), you MUST use a placeholder in this **exact** format: (Example: )
- **DO NOT** use the `ASSET_PATHS` variable, it will be empty.
- Output ONLY valid HTML starting with <!DOCTYPE html> and ending with </html>.
"""

SCORING_RUBRIC = r"""
You are an experienced front-end engineer. Compare two images: (A) the original wireframe, and (B) the generated HTML rendering, and read the HTML/CSS code used for (B).
Return a PLAIN-TEXT report with the following sections EXACTLY in this order (no JSON, no code fences around the whole report):

SCORES:
aesthetics: <0-10>
completeness: <0-10>
layout_fidelity: <0-10>   # be harsh; row alignment, relative sizes and positions must match A
text_legibility: <0-10>
visual_balance: <0-10>
aggregate:                # mean of the five scores

ISSUES_TOP3:
-
-
-

LAYOUT_DIFFS:
- component:
  a_bbox_pct: [x,y,w,h]   # approx percentages (0–100) of page width/height in A
  b_bbox_pct: [x,y,w,h]   # same for B
  fix:

CSS_PATCH:
```css
/* <= 40 lines, use existing selectors where possible; use px and hex colors */
.selector { property: value; }
/* ... */
```

HTML_EDITS:
-

REGENERATE_PROMPT:
<1–4 lines with exact grid, gaps (px), radii (px), hex colors, and font sizes to rebuild if needed>

FEEDBACK:
"""

REFINE_SYSTEM = "You are a senior frontend engineer who strictly applies critique to improve HTML/CSS while matching the wireframe."

REFINE_PROMPT = """
You are given:
1) (A) the original wireframe image
2) The CURRENT HTML (single-file) that produced (B) the rendering
3) A critique ("feedback") produced by a rubric-based comparison of A vs B

Task:
- Produce a NEW single-file HTML that addresses EVERY feedback point while staying faithful to A.
- Fix layout fidelity (columns, spacing, alignment), completeness (ensure all components in A exist), typography/contrast for legibility, and overall aesthetics and balance.
- Keep it self-contained (inline