import os
import subprocess
import gradio as gr
from PIL import Image as PILImage
import torchvision.transforms.functional as TF
import numpy as np
# Restrict visibility to a single GPU; set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import re
import ast
import io
import base64
import cv2
from typing import List, Tuple, Optional
import sys
import spaces
def add_sam2_to_path():
    sam2_dir = os.path.abspath("third_party/sam2")
    if sam2_dir not in sys.path:
        sys.path.insert(0, sam2_dir)
    return sam2_dir
def install_sam2():
    sam2_dir = "third_party/sam2"
    if not os.path.exists(sam2_dir):
        print("Installing SAM2...")
        os.makedirs("third_party", exist_ok=True)
        subprocess.run([
            "git", "clone",
            "--recursive",
            "https://github.com/facebookresearch/sam2.git",
            sam2_dir
        ], check=True)
        original_dir = os.getcwd()
        try:
            os.chdir(sam2_dir)
            subprocess.run(["pip", "install", "-e", "."], check=True)
        except Exception as e:
            print(f"Error during SAM2 installation: {str(e)}")
            raise
        finally:
            os.chdir(original_dir)
        print("✅ SAM2 installed successfully!")
    else:
        print("SAM2 already exists, skipping installation.")
install_sam2()
sam2_dir = add_sam2_to_path()
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
print("🎉 SAM2 modules imported successfully!")
MODEL_PATH = "geshang/Seg-R1-7B"
SAM_CHECKPOINT = "sam2_weights/sam2.1_hiera_large.pt"
DEVICE = "cuda"  # if torch.cuda.is_available() else "cpu"
RESIZE_SIZE = (1024, 1024)
try:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        device_map="auto" if DEVICE == "cuda" else None,
    )
    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    print(f"Qwen model loaded on {DEVICE}")
except Exception as e:
    print(f"Error loading Qwen model: {e}")
    model = None
    processor = None
# SAM Wrapper
class CustomSAMWrapper:
    def __init__(self, model_path: str, device: str = DEVICE):
        self.device = torch.device(device)
        sam_model = build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", model_path, self.device)
        sam_model = sam_model.to(self.device)
        self.predictor = SAM2ImagePredictor(sam_model)
        self.last_mask = None
        print(f"SAM model loaded on {device}")

    def predict(self, image: PILImage.Image,
                points: List[Tuple[int, int]],
                labels: List[int],
                bbox: Optional[List[List[int]]] = None) -> Tuple[np.ndarray, float]:
        if not self.predictor:
            return np.zeros((image.height, image.width), dtype=bool), 0.0
        try:
            input_points = np.array(points) if points else None
            input_labels = np.array(labels) if labels else None
            input_bboxes = np.array(bbox) if bbox else None
            # PIL images are already in RGB order, so no BGR->RGB conversion is needed
            # before handing the array to SAM2.
            image_np = np.array(image)
            self.predictor.set_image(image_np)
            mask_pred, score, logits = self.predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_bboxes,
                multimask_output=False,
            )
            self.last_mask = mask_pred[0]
            return mask_pred[0], score[0]
        except Exception as e:
            print(f"SAM prediction error: {e}")
            return np.zeros((image.height, image.width), dtype=bool), 0.0
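# Illustrative call with hypothetical prompt values (not from the original source):
#   wrapper = CustomSAMWrapper(SAM_CHECKPOINT)
#   mask, score = wrapper.predict(img, points=[[512, 384]], labels=[1], bbox=[[256, 200, 768, 600]])
# `mask` is an (H, W) array for the single predicted mask and `score` its confidence.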
def parse_custom_format(content: str):
    # Extract the <points>, <labels>, and <bbox> blocks from the model's text output.
    point_pattern = r"<points>\s*(\[\s*(?:\[\s*\d+\s*,\s*\d+\s*\]\s*,?\s*)+\])\s*</points>"
    label_pattern = r"<labels>\s*(\[\s*(?:\d+\s*,?\s*)+\])\s*</labels>"
    bbox_pattern = r"<bbox>\s*(\[\s*\d+\s*,\s*\d+\s*,\s*\d+\s*,\s*\d+\s*\])\s*</bbox>"
    point_match = re.search(point_pattern, content)
    label_match = re.search(label_pattern, content)
    bbox_matches = re.findall(bbox_pattern, content)
    try:
        # ast.literal_eval is a safer drop-in for eval() on these bracketed number lists.
        points = np.array(ast.literal_eval(point_match.group(1))) if point_match else None
        labels = np.array(ast.literal_eval(label_match.group(1))) if label_match else None
        if points is not None and labels is not None:
            # Points must form an (N, 2) array with exactly one label per point; otherwise discard both.
            if not (len(points.shape) == 2 and points.shape[1] == 2 and len(labels) == points.shape[0]):
                points, labels = None, None
        bboxes = []
        for bbox_str in bbox_matches:
            bbox = np.array(ast.literal_eval(bbox_str))
            if len(bbox.shape) == 1 and bbox.shape[0] == 4:
                bboxes.append(bbox)
        bboxes = np.stack(bboxes, axis=0) if bboxes else None
        return points, labels, bboxes
    except Exception as e:
        print("Error parsing content:", e)
        return None, None, None
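# Illustrative example with a hypothetical model response (not from the original source):
#   text = ("<think>...</think> <bbox>[120, 80, 640, 520]</bbox> "
#           "<points>[[300, 260], [50, 40]]</points> <labels>[1, 0]</labels>")
#   points, labels, bboxes = parse_custom_format(text)
#   # -> points.shape == (2, 2), labels == array([1, 0]), bboxes.shape == (1, 4)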
def prepare_test_messages(image, prompt):
    buffered = io.BytesIO()
    image = TF.resize(image, RESIZE_SIZE)
    image.save(buffered, format="JPEG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
        "process should enclosed within <think> </think> tags, and the bounding box, points and points labels should be enclosed within <bbox></bbox>, <points></points>, and <labels></labels>, respectively. i.e., "
        "<think> reasoning process here </think> <bbox>[x1,y1,x2,y2]</bbox>, <points>[[x3,y3],[x4,y4],...]</points>, <labels>[1,0,...]</labels>"
        "Where 1 indicates a foreground (object) point, and 0 indicates a background point."
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image/jpeg;base64,{img_base64}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    return [messages]
def answer_question(batch_messages):
    if not model or not processor:
        return ["Model not loaded. Please check logs."]
    try:
        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(text=text, images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True).to(DEVICE)
        outputs = model.generate(**inputs, use_cache=True, max_new_tokens=1024)
        # Strip the prompt tokens so only the newly generated answer is decoded.
        trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
        return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    except Exception as e:
        print(f"Error generating answer: {e}")
        return ["Error generating response"]
def visualize_masks_on_image(
    image: PILImage.Image,
    masks_np: list,
    colors=[(255, 0, 0), (0, 255, 0), (0, 0, 255),
            (255, 255, 0), (255, 0, 255), (0, 255, 255),
            (128, 128, 255)],
    alpha=0.5,
):
    if not masks_np:
        return image
    image_np = np.array(image)
    color_mask = np.zeros((image_np.shape[0], image_np.shape[1], 3), dtype=np.uint8)
    for i, mask in enumerate(masks_np):
        color = colors[i % len(colors)]
        mask = mask.astype(np.uint8)
        if mask.shape[:2] != image_np.shape[:2]:
            mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
        color_mask[:, :, 0] = color_mask[:, :, 0] | (mask * color[0])
        color_mask[:, :, 1] = color_mask[:, :, 1] | (mask * color[1])
        color_mask[:, :, 2] = color_mask[:, :, 2] | (mask * color[2])
    # Blend: (1 - alpha) * image + alpha * color overlay.
    blended = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
    return PILImage.fromarray(blended)
@spaces.GPU
def run_pipeline(image: PILImage.Image, prompt: str):
    # SAM2 is rebuilt per request; the @spaces.GPU decorator grants GPU access for the
    # duration of this call (assumption: this Space runs on ZeroGPU hardware, which is
    # what the unused `import spaces` suggests).
    sam_wrapper = CustomSAMWrapper(SAM_CHECKPOINT, device=DEVICE)
    if not model or not processor:
        return "Models not loaded. Please check logs.", None
    try:
        # Normalize mode so JPEG encoding and mask blending also work for RGBA/grayscale uploads.
        image = image.convert("RGB")
        img_original = image.copy()
        img_resized = TF.resize(image, RESIZE_SIZE)
        messages = prepare_test_messages(img_resized, prompt)
        output_text = answer_question(messages)[0]
        print(f"Model output: {output_text}")
        points, labels, bbox = parse_custom_format(output_text)
        mask_pred = None
        final_mask = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
        if (points is not None and labels is not None) or (bbox is not None):
            img = img_resized
            if bbox is not None and len(bbox.shape) == 2:
                # One SAM call per predicted box, keeping only the point prompts that fall inside it.
                for b in bbox:
                    b = b.tolist()
                    if points is not None and labels is not None:
                        in_bbox_mask = (
                            (points[:, 0] >= b[0]) & (points[:, 0] <= b[2]) &
                            (points[:, 1] >= b[1]) & (points[:, 1] <= b[3])
                        )
                        selected_points = points[in_bbox_mask]
                        selected_labels = labels[in_bbox_mask]
                    else:
                        selected_points, selected_labels = None, None
                    try:
                        mask, _ = sam_wrapper.predict(
                            img,
                            selected_points.tolist() if selected_points is not None and len(selected_points) > 0 else None,
                            selected_labels.tolist() if selected_labels is not None and len(selected_labels) > 0 else None,
                            b
                        )
                        final_mask |= (mask > 0)
                    except Exception as e:
                        print(f"Mask prediction error for bbox: {e}")
                        continue
                mask_pred = final_mask
            else:
                try:
                    mask_pred, _ = sam_wrapper.predict(
                        img,
                        points.tolist() if points is not None else None,
                        labels.tolist() if labels is not None else None,
                        bbox.tolist() if bbox is not None else None
                    )
                    mask_pred = mask_pred > 0
                except Exception as e:
                    print(f"Mask prediction error: {e}")
                    mask_pred = np.zeros(RESIZE_SIZE[::-1], dtype=bool)
        else:
            return output_text, None
        mask_np = mask_pred
        # Resize the predicted mask back to the original resolution before overlaying.
        mask_img = PILImage.fromarray((mask_np * 255).astype(np.uint8)).resize(img_original.size)
        mask_img = mask_img.convert("L")
        mask_np = np.array(mask_img) > 128
        visualized_img = visualize_masks_on_image(
            img_original,
            masks_np=[mask_np],
            alpha=0.6
        )
        # Show only the <think> block in the text panel when present.
        match = re.search(r'(<think>.*?</think>)', output_text, re.DOTALL)
        if match:
            output_text = match.group(1)
        return output_text, visualized_img
    except Exception as e:
        print(f"Pipeline error: {e}")
        return f"Error processing request: {str(e)}", None
def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
with gr.Blocks(title="Seg-R1") as demo:
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(lines=2, label="Question", placeholder="Ask about objects in the image...")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Model Response", interactive=False)
            image_output = gr.Image(type="pil", label="Segmentation Result", interactive=False)
    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_input, text_input],
        outputs=[text_output, image_output]
    )
    gr.Examples(
        examples=[
            ["imgs/camourflage1.jpg", "There is a creature hidden in its surroundings, segment it."],
            ["imgs/camourflage2.jpg", "Please segment the camouflaged object in this image."],
            ["imgs/dog_in_sheeps.jpg", "Find the one that suffers."],
            ["imgs/kind_lady.jpg", "Find the most uncommon part of this picture."],
            ["imgs/painting.jpg", "Identify and segment the man and the sky."],
            ["imgs/man_and_cat.jpg", "Identify and segment the cat and the glasses of the man."],
        ],
        inputs=[image_input, text_input],
        outputs=[text_output, image_output],
        fn=run_pipeline,
        cache_examples=True
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
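# Usage sketch (assumption, not part of the original app): with the demo running locally,
# it can also be queried programmatically via gradio_client. The endpoint name may differ
# by Gradio version; check client.view_api() if "/run_pipeline" is not found.
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   text, seg = client.predict(
#       handle_file("imgs/camourflage1.jpg"),
#       "Please segment the camouflaged object in this image.",
#       api_name="/run_pipeline",
#   )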