import gradio as gr
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
# ==============================
# βœ… λͺ¨λΈ λ‘œλ“œ
# ==============================
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
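# This checkpoint predicts the 19 Cityscapes classes; from_pretrained returns the model in eval mode.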
# ==============================
# ✅ Load palette / labels
# ==============================
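# Standard 19-class Cityscapes color palette (RGB), index-aligned with the model's label ids.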
def city_palette():
return [
[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
[153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
[70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]
]
colormap = np.asarray(city_palette(), dtype=np.uint8)
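# labels.txt is expected to hold one class name per line, in the same order as the palette above.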
with open("labels.txt", "r", encoding="utf-8") as f:
    labels_list = [line.strip() for line in f]
def label_to_color_image(label):
if label.ndim != 2:
raise ValueError("Expect 2-D input label")
if np.max(label) >= len(colormap):
raise ValueError("Label value too large.")
return colormap[label]
# ==============================
# βœ… μ‹œκ°ν™” ν•¨μˆ˜
# ==============================
def draw_plot(pred_img, seg_np):
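    # Left panel: blended prediction overlay; right panel: legend of only the classes present in the image.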
fig = plt.figure(figsize=(20, 15))
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
plt.subplot(grid_spec[0])
plt.imshow(pred_img)
plt.axis('off')
LABEL_NAMES = np.asarray(labels_list)
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
unique_labels = np.unique(seg_np.astype("uint8"))
ax = plt.subplot(grid_spec[1])
plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
ax.yaxis.tick_right()
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
plt.xticks([], [])
ax.tick_params(width=0.0, labelsize=22)
plt.tight_layout()
return fig
# ==============================
# ✅ Inference
# ==============================
def run_inference(input_img):
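    # Gradio passes the upload as a NumPy array by default; normalize it to an RGB PIL image.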
img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img
if img.mode != "RGB":
img = img.convert("RGB")
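    # The processor resizes and normalizes the image into the tensor format the model expects.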
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
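    # SegFormer emits logits at 1/4 of the input resolution; img.size is (W, H), so its
    # reverse gives the (H, W) target size for upsampling.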
upsampled = torch.nn.functional.interpolate(
logits, size=img.size[::-1], mode="bilinear", align_corners=False
)
seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
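    # Map per-pixel class ids to palette colors and alpha-blend 50/50 with the input image.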
color_seg = colormap[seg]
pred_img = (np.array(img) * 0.5 + color_seg * 0.5).astype(np.uint8)
return draw_plot(pred_img, seg)
# ==============================
# 🌆 Gradio UI layout
# ==============================
with gr.Blocks(theme=gr.themes.Soft(), css="""
#title {text-align:center; font-size:2.2em; font-weight:700; margin-bottom:0.5em;}
#desc {text-align:center; color:#555; font-size:1.1em; margin-bottom:2em;}
#footer {text-align:center; color:#888; margin-top:2em; font-size:0.9em;}
""") as demo:
gr.Markdown("<div id='title'>πŸ™οΈ City Segmenter</div>")
gr.Markdown("<div id='desc'>SegFormer-B2 기반 λ„μ‹œ 이미지 μ„Έκ·Έλ©˜ν…Œμ΄μ…˜.<br>λ„λ‘œ, 건물, μ°¨λŸ‰, μ‚¬λžŒ 등을 μƒ‰μƒμœΌλ‘œ λΆ„λ¦¬ν•˜μ—¬ μ‹œκ°ν™”ν•©λ‹ˆλ‹€.</div>")
with gr.Row():
with gr.Column(scale=1):
            input_img = gr.Image(type="numpy", label="📤 Upload image")
example_box = gr.Examples(
examples=[
"city1.jpeg",
"city2.jpg",
"city3.jpg"
],
inputs=input_img,
label="πŸ–ΌοΈ μ˜ˆμ‹œ 이미지"
)
            run_btn = gr.Button("🚀 Run analysis", variant="primary")
with gr.Column(scale=2):
            output_plot = gr.Plot(label="🧭 Result visualization (Overlay + Legend)")
run_btn.click(fn=run_inference, inputs=input_img, outputs=output_plot)
gr.Markdown("<div id='footer'>Β© 2025 City Segmenter β€’ Powered by NVIDIA SegFormer-B2</div>")
# ==============================
if __name__ == "__main__":
demo.launch()