import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# ✅ 1. Load the model
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
# Run on the GPU when one is available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)

# ✅ 2. Cityscapes palette (19 classes): row i is the RGB colour drawn for class id i.
CITYSCAPES_COLORS = np.array([
    [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
    [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
    [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
    [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
    [0, 80, 100], [0, 0, 230], [119, 11, 32],
])

# Human-readable name for each class id, in the same order as CITYSCAPES_COLORS.
CITYSCAPES_LABELS = [
    "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
    "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
    "truck", "bus", "train", "motorcycle", "bicycle",
]


# ✅ 3. Segmentation function
def _build_legend(class_ids):
    """Return a matplotlib Figure with one coloured bar per class id in *class_ids*.

    Uses the object-oriented ``Figure`` API rather than ``pyplot``: pyplot keeps
    every figure it creates alive in a global registry, so creating one per
    request in a server leaks memory. pyplot's "current figure" is also global
    state shared between concurrent requests.
    """
    n_colors = len(CITYSCAPES_COLORS)
    n_names = len(CITYSCAPES_LABELS)

    fig = Figure(figsize=(3, 6), tight_layout=True)
    ax = fig.subplots()
    for row, class_id in enumerate(class_ids):
        # Same modulo wrap as the mask colouring, so legend colours always match the mask.
        ax.barh(row, 1, color=CITYSCAPES_COLORS[class_id % n_colors] / 255)
    ax.set_yticks(range(len(class_ids)))
    ax.set_yticklabels(
        [CITYSCAPES_LABELS[c] if c < n_names else f"class {c}" for c in class_ids],
        fontsize=10,
    )
    ax.invert_yaxis()  # list the first class at the top
    ax.set_xticks([])
    ax.set_title("Color Legend", fontsize=12)
    return fig


def segment_cityscape(input_image):
    """Segment a street-scene image into the Cityscapes classes.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user has not uploaded anything.

    Returns:
        A 4-tuple ``(image, seg_image, blended, fig)``:
        - ``image``: the input converted to RGB;
        - ``seg_image``: the colour-coded class mask, same size as the input;
        - ``blended``: the mask blended over the input (mask weight 0.6);
        - ``fig``: a matplotlib Figure legend for the classes present.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("먼저 도시 거리 이미지를 업로드하세요.")

    image = input_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # Presumably (1, num_labels, h, w) at the model's reduced output resolution.
        logits = model(**inputs).logits
        # PIL's size is (width, height); interpolate expects (height, width).
        upsampled = torch.nn.functional.interpolate(
            logits, size=image.size[::-1], mode="bilinear", align_corners=False
        )
        # Per-pixel class id for the single image in the batch.
        seg = upsampled.argmax(dim=1)[0].cpu().numpy()

    # Look up each pixel's palette colour; the modulo guards against ids past the palette.
    seg_color = CITYSCAPES_COLORS[seg % len(CITYSCAPES_COLORS)]
    seg_image = Image.fromarray(seg_color.astype(np.uint8))
    # Image.blend gives image*(1-alpha) + mask*alpha, so the mask dominates.
    blended = Image.blend(image, seg_image, alpha=0.6)

    fig = _build_legend(np.unique(seg))
    return image, seg_image, blended, fig


# ✅ 4.
Gradio 인터페이스 (창의적 UI) with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.HTML( """
도시 풍경 이미지에서 도로, 차량, 사람, 하늘 등 19가지 클래스를 자동으로 분할합니다.
""" ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📸 입력") input_img = gr.Image(type="pil", label="도시 거리 이미지를 업로드하세요") run_btn = gr.Button("🚀 세그멘테이션 실행", variant="primary") with gr.Column(scale=2): with gr.Tab("원본"): orig = gr.Image(label="Original") with gr.Tab("세그멘테이션 결과"): mask = gr.Image(label="Segmentation Mask") with gr.Tab("오버레이(합성)"): overlay = gr.Image(label="Overlay Result") with gr.Tab("색상 범례"): legend_plot = gr.Plot(label="Color Legend") run_btn.click(segment_cityscape, inputs=input_img, outputs=[orig, mask, overlay, legend_plot]) with gr.Accordion("🎞️ 예시 이미지로 테스트", open=False): gr.Examples( examples=[ ["city1.jpeg"], ["city2.jpg"] ], inputs=input_img ) gr.HTML( """🔹 Model: nvidia/segformer-b2-finetuned-cityscapes-1024-1024
🔹 Trained on: Cityscapes Dataset (urban outdoor scenes)