|
|
import logging |
|
|
import os |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
try: |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
import base64 |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import google.generativeai as genai |
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Model identifier; the GEMINI_IMAGE_MODEL environment variable overrides the fallback.
DEFAULT_MODEL_ID = os.environ.get("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image")

# Instruction text used when the caller passes no prompt. The GEMINI_IMAGE_PROMPT
# environment variable replaces it wholesale. Each tuple entry is one line of the
# prompt; an empty entry is a blank separator line between sections.
DEFAULT_PROMPT = os.environ.get(
    "GEMINI_IMAGE_PROMPT",
    "\n".join(
        (
            "TASK TYPE: STRICT IMAGE INPAINTING — OBJECT REMOVAL ONLY",
            "",
            "You are given:",
            "1) An original image",
            "2) A binary mask image",
            "",
            "MASK RULE (MANDATORY):",
            "• White pixels (#FFFFFF) indicate the exact region to be REMOVED.",
            "• Black pixels (#000000) indicate regions that MUST remain completely unchanged.",
            "",
            "PRIMARY OBJECTIVE:",
            "Completely delete everything inside the white masked area.",
            "The object in the white region must be fully removed with no visible remnants,",
            "no partial shapes, no outlines, no shadows, and no color traces.",
            "",
            "INPAINTING INSTRUCTIONS:",
            "Ignore the content of the white masked area entirely.",
            "Reconstruct that region using ONLY surrounding background information.",
            "Extend nearby background textures, patterns, and structures naturally.",
            "Match lighting direction, brightness, contrast, color temperature, and noise.",
            "Continue edges, lines, and surfaces realistically across the removed area.",
            "Blend boundaries smoothly so the edit is visually undetectable.",
            "",
            "Return only the edited image.",
            "",
            "STRICT CONSTRAINTS:",
            "• Do NOT generate or keep any part of the removed object.",
            "• Do NOT invent new objects or details.",
            "• Do NOT repaint, modify, blur, or enhance any black (unmasked) area.",
            "• Do NOT change the original image composition.",
            "• Do NOT change camera angle, perspective, or scale.",
            "",
            "QUALITY REQUIREMENTS:",
            "• No ghosting or transparent object remains.",
            "• No edge halos or smearing.",
            "• No repeated textures or patchy fills.",
            "• Result must look like the object never existed.",
            "",
            "FAILURE CONDITIONS (MUST BE AVOIDED):",
            "If any object fragment, outline, shadow, or color from the removed object",
            "is still visible, the result is incorrect and must be re-generated.",
        )
    ),
)
|
|
# Cached Gemini model instance shared by the whole module; it stays None until
# _get_gemini_model() builds it on first use.
_GENAI_MODEL: genai.GenerativeModel | None = None
|
|
|
|
|
|
|
|
def _resize_mask(mask: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray:
    """Scale *mask* to the requested ``(height, width)``.

    The input array itself is returned when its first two dimensions already
    match; otherwise a nearest-neighbour resize is performed.
    """
    height, width = target_hw
    if mask.shape[:2] != (height, width):
        # cv2.resize takes the destination size as (width, height).
        return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return mask
|
|
|
|
|
|
|
|
def _binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool) -> np.ndarray:
    """Normalize an incoming mask to a 0/255 ``uint8`` binary mask.

    Before the optional flip, 255 marks a pixel to remove:

    - If the mask has an alpha channel whose mean is below 240, pixels with
      alpha below 128 (transparent) are treated as "remove".
    - Otherwise pixels whose grayscale value is above 128 (bright) are treated
      as "remove".

    Args:
        mask: A 2-D grayscale mask, an (H, W, 1) mask, an (H, W, 3) RGB mask, or
            an array with at least four channels where channel 3 is alpha.
        invert_mask: When true the mask is returned as computed above; when
            false the result is flipped (``255 - mask``).

    Returns:
        A 2-D ``uint8`` array containing only 0 and 255.
    """
    gray = None
    rgb_channels = None
    alpha_channel = None

    if mask.ndim == 2:
        # Plain grayscale mask: no alpha, no colour conversion needed.
        gray = mask
    elif mask.shape[2] == 1:
        gray = mask[:, :, 0]
    elif mask.shape[2] == 3:
        # No alpha channel means fully opaque, so only brightness is considered.
        rgb_channels = mask
    else:
        alpha_channel = mask[:, :, 3]
        rgb_channels = mask[:, :, :3]

    if alpha_channel is not None and alpha_channel.mean() < 240:
        # The mask carries real transparency: transparent pixels are the selection.
        remove = alpha_channel < 128
    else:
        if gray is None:
            gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
        remove = gray > 128

    mask_bw = np.where(remove, 255, 0).astype(np.uint8)
    if not invert_mask:
        mask_bw = 255 - mask_bw
    return mask_bw
|
|
|
|
|
|
|
|
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Serialize *img* as PNG in memory and return the encoded bytes."""
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        # getvalue() returns the entire buffer regardless of the stream position.
        return stream.getvalue()
|
|
|
|
|
|
|
|
def _get_gemini_model() -> genai.GenerativeModel:
    """Return the shared Gemini model, creating it on first use.

    The API key is taken from the first non-empty variable among
    GEMINI_API_KEY, GOOGLE_API_KEY and GOOGLE_GENAI_API_KEY. The model id comes
    from GEMINI_IMAGE_MODEL, falling back to ``DEFAULT_MODEL_ID``.

    Raises:
        RuntimeError: If none of the API-key variables holds a value.
    """
    global _GENAI_MODEL
    if _GENAI_MODEL is not None:
        return _GENAI_MODEL

    key_variables = ("GEMINI_API_KEY", "GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY")
    api_key = next((value for value in map(os.environ.get, key_variables) if value), None)
    if not api_key:
        raise RuntimeError("Set Gemini API key via GEMINI_API_KEY / GOOGLE_API_KEY / GOOGLE_GENAI_API_KEY")

    genai.configure(api_key=api_key)
    # Re-read the environment here so a value set after import is still honoured.
    _GENAI_MODEL = genai.GenerativeModel(os.environ.get("GEMINI_IMAGE_MODEL", DEFAULT_MODEL_ID))
    return _GENAI_MODEL
|
|
|
|
|
|
|
|
def _first_inline_image(candidates) -> Image.Image | None:
    """Return the first inline image found in *candidates* as RGB, or None if there is none."""
    log.debug("Number of candidates: %d", len(candidates))

    for idx, candidate in enumerate(candidates):
        content = getattr(candidate, "content", None)
        if not content:
            log.debug("Candidate %d has no content", idx)
            continue
        response_parts = getattr(content, "parts", None)
        if not response_parts:
            log.debug("Candidate %d content has no parts", idx)
            continue
        log.debug("Candidate %d has %d parts", idx, len(response_parts))

        for part_idx, part in enumerate(response_parts):
            inline = getattr(part, "inline_data", None)
            if not inline:
                # No image payload in this part; surface any text the model sent instead.
                text = getattr(part, "text", None)
                if text:
                    log.warning("Gemini returned text instead of image in part %d: %s", part_idx, text[:200])
                continue

            log.debug("Part %d has inline_data, mime_type: %s", part_idx, getattr(inline, "mime_type", None))
            if not inline.data:
                continue
            data = inline.data
            if isinstance(data, str):
                # A string payload is treated as base64-encoded image bytes.
                data = base64.b64decode(data)
            image = Image.open(BytesIO(data)).convert("RGB")
            log.info("Successfully extracted image from Gemini response")
            return image

    return None


def _call_gemini_edit(
    image_rgb: np.ndarray,
    mask_bw: np.ndarray,
    prompt: str | None,
    target_size: tuple[int, int],
) -> Image.Image:
    """Ask Gemini to remove the masked region and return the edited image.

    Two PNG images are sent after the prompt via ``generate_content``: the
    source image with the masked region painted white, then the binary mask.

    Args:
        image_rgb: Source image array; pixels where ``mask_bw > 0`` are painted
            white in the copy that is sent.
        mask_bw: Binary mask; non-zero pixels mark the region to remove.
        prompt: Instruction text; ``DEFAULT_PROMPT`` is used when empty or None.
        target_size: ``(width, height)`` the returned image is resized to when
            the model's output has a different size.

    Returns:
        The edited image in RGB mode at ``target_size``.

    Raises:
        RuntimeError: If the API call fails, no candidates are returned, a
            candidate reports finish reason 17 or 2, or no image can be
            extracted from the response.
    """
    model = _get_gemini_model()

    # "Image A": the source with the removal region blanked out in white.
    guidance_rgb = image_rgb.copy()
    guidance_rgb[mask_bw > 0] = 255
    guidance_bytes = _pil_to_png_bytes(Image.fromarray(guidance_rgb).convert("RGB"))
    # "Image B": the mask itself as a single-channel image.
    mask_bytes = _pil_to_png_bytes(Image.fromarray(mask_bw).convert("L"))

    effective_prompt = (
        (prompt or DEFAULT_PROMPT).strip()
        + "\nIMAGE ORDER:\n"
        + "Image A: Original photo with the removal region painted white.\n"
        + "Image B: Binary mask (white=remove, black=keep). Use this mask to decide what to remove.\n"
    )
    log.info(
        "Calling Gemini generate_content model=%s (mask-guided remove) mask_pixels=%d",
        model.model_name,
        int((mask_bw > 0).sum()),
    )

    content = [
        effective_prompt,
        {"mime_type": "image/png", "data": guidance_bytes},
        {"mime_type": "image/png", "data": mask_bytes},
    ]

    try:
        response = model.generate_content(content, stream=False)
    except Exception as gen_err:
        log.error("Gemini generate_content raised exception: %s", gen_err, exc_info=True)
        # Chain the cause so the original traceback is preserved for callers.
        raise RuntimeError(f"Gemini API error: {gen_err}") from gen_err

    candidates = getattr(response, "candidates", [])
    if not candidates:
        log.error("Gemini returned no candidates")
        raise RuntimeError("Gemini API returned no candidates. The request may have been blocked.")

    for candidate in candidates:
        finish_reason = getattr(candidate, "finish_reason", None)
        if not finish_reason:
            continue
        # NOTE(review): in the Gemini API FinishReason enum 1 is STOP, 2 is MAX_TOKENS
        # and 3 is SAFETY — confirm that {17, 2} is the intended "blocked" set.
        if finish_reason in (17, 2):
            safety_ratings = getattr(candidate, "safety_ratings", [])
            log.error("Gemini blocked the request. Finish reason: %s, Safety ratings: %s", finish_reason, safety_ratings)
            raise RuntimeError(f"Gemini API blocked the content (finish_reason={finish_reason}). The image may violate safety policies.")
        if finish_reason != 1:
            # 1 (STOP) is the normal completion and is not worth a warning.
            log.warning("Gemini finished with non-zero reason: %s", finish_reason)

    output_img: Image.Image | None = None
    try:
        output_img = _first_inline_image(candidates)
    except Exception as err:
        # Best effort: a malformed payload is reported below as "no image".
        log.error("Failed to parse Gemini response image: %s", err, exc_info=True)

    if output_img is None:
        # Dump whatever diagnostics are available; this must never mask the real error.
        try:
            log.error(
                "Gemini generate_content returned no image. Full response (first 1000 chars): %s",
                str(response)[:1000],
            )
            if hasattr(response, "prompt_feedback"):
                log.error("Prompt feedback: %s", response.prompt_feedback)
            for idx, candidate in enumerate(candidates):
                log.error("Candidate %d finish_reason: %s", idx, getattr(candidate, "finish_reason", None))
        except Exception:
            log.debug("Could not log Gemini response diagnostics", exc_info=True)
        raise RuntimeError("Gemini generate_content returned no image. Check logs for details.")

    if output_img.size != target_size:
        output_img = output_img.resize(target_size, Image.Resampling.LANCZOS)

    return output_img
|
|
|
|
|
|
|
|
def process_inpaint(
    image: np.ndarray,
    mask: np.ndarray,
    invert_mask: bool = True,
    prompt: str | None = None,
) -> np.ndarray:
    """Remove the masked region of *image* by delegating to the Gemini edit call.

    Args:
        image: Source image array, loaded through ``Image.fromarray`` and
            flattened to RGB.
        mask: Mask array, loaded through ``Image.fromarray``, converted to RGBA,
            reduced to a 0/255 binary mask and resized to the image size.
        invert_mask: Forwarded to ``_binary_mask_from_rgba``; when false the
            binary mask is flipped.
        prompt: Optional instruction text forwarded to ``_call_gemini_edit``.

    Returns:
        The edited image as an array.
    """
    source_rgb = np.array(Image.fromarray(image).convert("RGBA").convert("RGB"))
    height, width = source_rgb.shape[:2]

    mask_as_rgba = np.array(Image.fromarray(mask).convert("RGBA"))
    removal_mask = _resize_mask(
        _binary_mask_from_rgba(mask_as_rgba, invert_mask),
        (height, width),
    )

    # PIL sizes are (width, height), the reverse of the array's (height, width).
    result = _call_gemini_edit(source_rgb, removal_mask, prompt, (width, height))
    return np.array(result)
|
|
|