import gradio as gr import matplotlib from matplotlib import gridspec import matplotlib.pyplot as plt import numpy as np from PIL import Image import torch from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation import time # 모델 로드 MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024" processor = AutoImageProcessor.from_pretrained(MODEL_ID) model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID) def ade_palette(): """ADE20K palette that maps each class to RGB values.""" return [ [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32] ] # labels.txt 파일 읽기 labels_list = [] with open("labels.txt", "r", encoding="utf-8") as fp: for line in fp: labels_list.append(line.rstrip("\n")) colormap = np.asarray(ade_palette(), dtype=np.uint8) def label_to_color_image(label): if label.ndim != 2: raise ValueError("Expect 2-D input label") if np.max(label) >= len(colormap): raise ValueError("label value too large.") return colormap[label] # ✅ [수정됨] : figsize와 width_ratios를 늘려서 이미지를 크게 만듦 def draw_plot(pred_img, seg_np): # Figure의 전체 크기를 (25, 20)으로 늘림 fig = plt.figure(figsize=(25, 20)) # 이미지와 범례의 너비 비율을 8:1로 변경 (이미지가 더 넓어짐) grid_spec = gridspec.GridSpec(1, 2, width_ratios=[8, 1]) plt.subplot(grid_spec[0]) plt.imshow(pred_img) plt.axis('off') LABEL_NAMES = np.asarray(labels_list) FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1) FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP) unique_labels = np.unique(seg_np.astype("uint8")) valid_labels = [label for label in unique_labels if label < len(LABEL_NAMES)] ax = plt.subplot(grid_spec[1]) plt.imshow(FULL_COLOR_MAP[valid_labels].astype(np.uint8), interpolation="nearest") ax.yaxis.tick_right() plt.yticks(range(len(valid_labels)), LABEL_NAMES[valid_labels]) plt.xticks([], []) ax.tick_params(width=0.0, labelsize=25) return fig # ✅ [수정됨] : 'alpha' 파라미터를 받는 '슬라이더 버전' def run_inference(input_img, alpha=0.5): start_time = time.time() img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img if img.mode != "RGB": img = img.convert("RGB") inputs = processor(images=img, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits upsampled = torch.nn.functional.interpolate( logits, size=img.size[::-1], mode="bilinear", align_corners=False ) seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8) color_seg = colormap[seg] # alpha 변수를 사용해 투명도 조절 image_weight = 1.0 - alpha overlay_weight = alpha pred_img = (np.array(img) * image_weight + color_seg * overlay_weight).astype(np.uint8) fig = draw_plot(pred_img, seg) print(f"Inference time: {time.time() - start_time:.2f}s") return fig # 다크 테마 정의 custom_theme = gr.themes.Soft( primary_hue="emerald", # 메인 색상: 청록빛 초록 secondary_hue="teal", # 보조 색상: 진한 청록 neutral_hue="slate" # 기본 톤 유지 (어두운 회색계열) ).set( body_background_fill="#0f172a", # 어두운 배경 유지 (다크모드) body_text_color="#e2f1e8", # 살짝 초록빛이 도는 밝은 텍스트 button_primary_background_fill="#10b981", # 메인 버튼색 (emerald-500) button_primary_text_color="#ffffff", # 버튼 안 글자색 (흰색) block_background_fill="#1a2e25", # 블록 영역: 짙은 녹색 톤 배경 ) demo = gr.Interface( fn=run_inference, # ✅ [수정됨] : inputs에 슬라이더 다시 추가 inputs=[ gr.Image(type="numpy", label="📸 Input Image"), gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay Transparency (투명도)") ], outputs=gr.Plot(label="Overlay + Legend"), # ✅ [수정됨] : examples를 중첩 리스트로 변경 examples=[ ["city1.png", 0.5], ["city2.png", 0.5], ["city3.jpg", 0.5], ["city4.jpeg", 0.5], ["city5.jpg", 0.5] ], flagging_mode="never", cache_examples=False, title="🏙️ City Segment", description=( "segformer-b2모델을 이용 도시 이미지 분할 시각.
" "이미지를 업로드하면 도로, 건물, 차량, 사람 등 객체별로 색상으로 구분해줍니다." ), theme=custom_theme ) if __name__ == "__main__": demo.launch()