object_remover / src /core.py
LogicGoInfotechSpaces's picture
Update src/core.py
961280b verified
import logging
import os
from io import BytesIO
# Load environment variables from .env if present (helps local dev)
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
import base64
import cv2
import numpy as np
from PIL import Image
import google.generativeai as genai
log = logging.getLogger(__name__)
# Remote inference configuration (Gemini API key only; no Vertex required).
# Both values below are read from the environment once, at import time.

# Model identifier passed to genai.GenerativeModel; the GEMINI_IMAGE_MODEL
# environment variable overrides the built-in default.
DEFAULT_MODEL_ID = os.environ.get("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image")
# Instruction text used by _call_gemini_edit when the caller supplies no prompt;
# the GEMINI_IMAGE_PROMPT environment variable overrides the built-in text.
# The built-in text describes the mask convention (white = remove, black = keep)
# and lists the constraints and quality requirements for the edit.
DEFAULT_PROMPT = os.environ.get(
"GEMINI_IMAGE_PROMPT",
(
"TASK TYPE: STRICT IMAGE INPAINTING — OBJECT REMOVAL ONLY\n\n"
"You are given:\n"
"1) An original image\n"
"2) A binary mask image\n\n"
"MASK RULE (MANDATORY):\n"
"• White pixels (#FFFFFF) indicate the exact region to be REMOVED.\n"
"• Black pixels (#000000) indicate regions that MUST remain completely unchanged.\n\n"
"PRIMARY OBJECTIVE:\n"
"Completely delete everything inside the white masked area.\n"
"The object in the white region must be fully removed with no visible remnants,\n"
"no partial shapes, no outlines, no shadows, and no color traces.\n\n"
"INPAINTING INSTRUCTIONS:\n"
"Ignore the content of the white masked area entirely.\n"
"Reconstruct that region using ONLY surrounding background information.\n"
"Extend nearby background textures, patterns, and structures naturally.\n"
"Match lighting direction, brightness, contrast, color temperature, and noise.\n"
"Continue edges, lines, and surfaces realistically across the removed area.\n"
"Blend boundaries smoothly so the edit is visually undetectable.\n\n"
"Return only the edited image.\n\n"
"STRICT CONSTRAINTS:\n"
"• Do NOT generate or keep any part of the removed object.\n"
"• Do NOT invent new objects or details.\n"
"• Do NOT repaint, modify, blur, or enhance any black (unmasked) area.\n"
"• Do NOT change the original image composition.\n"
"• Do NOT change camera angle, perspective, or scale.\n\n"
"QUALITY REQUIREMENTS:\n"
"• No ghosting or transparent object remains.\n"
"• No edge halos or smearing.\n"
"• No repeated textures or patchy fills.\n"
"• Result must look like the object never existed.\n\n"
"FAILURE CONDITIONS (MUST BE AVOIDED):\n"
"If any object fragment, outline, shadow, or color from the removed object\n"
"is still visible, the result is incorrect and must be re-generated."
),
)
# Module-wide cache for the model handle: stays None until
# _get_gemini_model() builds the model on its first successful call.
_GENAI_MODEL: genai.GenerativeModel | None = None
def _resize_mask(mask: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray:
    """Return ``mask`` scaled to the requested ``(height, width)``.

    A mask whose first two dimensions already equal ``target_hw`` is handed
    back untouched (the same object, no copy). Any other mask is resampled
    with nearest-neighbour interpolation.
    """
    height, width = target_hw
    needs_resize = mask.shape[:2] != (height, width)
    if needs_resize:
        # cv2.resize takes its size argument as (width, height).
        return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return mask
def _binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool) -> np.ndarray:
    """
    Collapse a 3- or 4-channel mask array into a single-channel 0/255 mask.

    Selection rules (255 marks a selected pixel before the optional flip):
    - When the mean alpha is below 240, alpha decides: pixels with
      alpha < 128 are selected.
    - Otherwise brightness decides: pixels whose grayscale value is above
      128 are selected.
    - A 3-channel input carries no alpha and is treated as fully opaque,
      so it always takes the brightness route.

    With ``invert_mask`` true the selection is returned as computed; with it
    false the result is flipped (255 <-> 0).

    NOTE(review): on an otherwise opaque mask the mean-alpha test only fires
    when more than roughly 6% of the pixels are fully transparent; smaller
    transparent regions fall through to the brightness route. Confirm this
    matches the masks the caller actually produces.
    """
    has_alpha = mask.shape[2] != 3
    if has_alpha:
        alpha = mask[:, :, 3]
        rgb = mask[:, :, :3]
    else:
        alpha = np.full(mask.shape[:2], 255, dtype=np.uint8)
        rgb = mask

    # A low average alpha means the alpha plane carries the mask information.
    alpha_is_informative = alpha.mean() < 240
    if alpha_is_informative:
        selected = alpha < 128
    else:
        selected = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > 128

    binary = selected.astype(np.uint8) * 255
    return binary if invert_mask else 255 - binary
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Serialize ``img`` as PNG into memory and return the encoded bytes."""
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        # getvalue() returns the whole buffer regardless of stream position.
        return stream.getvalue()
def _get_gemini_model() -> genai.GenerativeModel:
    """Return the shared Gemini model, building it the first time it is needed.

    The API key is taken from the first of GEMINI_API_KEY, GOOGLE_API_KEY and
    GOOGLE_GENAI_API_KEY that holds a non-empty value. The model id is read
    from GEMINI_IMAGE_MODEL at build time, falling back to DEFAULT_MODEL_ID.
    The built model is stored in the module-level ``_GENAI_MODEL`` and reused
    by later calls.

    Raises:
        RuntimeError: If none of the API-key variables is set to a non-empty
            value.
    """
    global _GENAI_MODEL
    if _GENAI_MODEL is not None:
        return _GENAI_MODEL

    for var_name in ("GEMINI_API_KEY", "GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY"):
        api_key = os.environ.get(var_name)
        if api_key:
            break
    else:
        raise RuntimeError("Set Gemini API key via GEMINI_API_KEY / GOOGLE_API_KEY / GOOGLE_GENAI_API_KEY")

    genai.configure(api_key=api_key)
    _GENAI_MODEL = genai.GenerativeModel(
        os.environ.get("GEMINI_IMAGE_MODEL", DEFAULT_MODEL_ID)
    )
    return _GENAI_MODEL
# Names of Candidate.FinishReason members that mean the output was withheld
# for policy reasons rather than generated.
_BLOCKING_FINISH_REASONS = frozenset(
    {"SAFETY", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY"}
)

# Fallback names for finish reasons that arrive as bare integers instead of
# enum members. In the Gemini API enum STOP is 1 (0 is "unspecified").
_FINISH_REASON_NAMES = {
    0: "FINISH_REASON_UNSPECIFIED",
    1: "STOP",
    2: "MAX_TOKENS",
    3: "SAFETY",
    4: "RECITATION",
    5: "OTHER",
}


def _finish_reason_name(finish_reason) -> str:
    """Return the symbolic name of a finish reason given as enum member or int.

    Values that are neither an enum member nor a known integer are returned
    as their string form, so callers only ever compare names.
    """
    name = getattr(finish_reason, "name", None)
    if isinstance(name, str):
        return name
    try:
        return _FINISH_REASON_NAMES.get(int(finish_reason), str(finish_reason))
    except (TypeError, ValueError):
        return str(finish_reason)


def _raise_if_blocked(candidates) -> None:
    """Raise RuntimeError if any candidate was blocked; warn on other abnormal endings."""
    for candidate in candidates:
        finish_reason = getattr(candidate, "finish_reason", None)
        if not finish_reason:
            # Missing or unspecified (0): nothing to report.
            continue
        reason_name = _finish_reason_name(finish_reason)
        if reason_name in _BLOCKING_FINISH_REASONS:
            safety_ratings = getattr(candidate, "safety_ratings", [])
            log.error("Gemini blocked the request. Finish reason: %s, Safety ratings: %s", finish_reason, safety_ratings)
            raise RuntimeError(f"Gemini API blocked the content (finish_reason={finish_reason}). The image may violate safety policies.")
        if reason_name != "STOP":
            log.warning("Gemini finished with non-STOP reason: %s", finish_reason)


def _extract_first_image(candidates) -> Image.Image | None:
    """Return the first inline image found in the candidates' parts, or None."""
    log.debug("Number of candidates: %d", len(candidates))
    for idx, candidate in enumerate(candidates):
        content = getattr(candidate, "content", None)
        if not content:
            log.debug("Candidate %d has no content", idx)
            continue
        response_parts = getattr(content, "parts", None)
        if not response_parts:
            log.debug("Candidate %d content has no parts", idx)
            continue
        log.debug("Candidate %d has %d parts", idx, len(response_parts))
        for part_idx, part in enumerate(response_parts):
            inline = getattr(part, "inline_data", None)
            if inline:
                log.debug("Part %d has inline_data, mime_type: %s", part_idx, getattr(inline, "mime_type", None))
                data = inline.data
                if data:
                    # Payload may arrive base64-encoded as text; decode to raw bytes.
                    if isinstance(data, str):
                        data = base64.b64decode(data)
                    image = Image.open(BytesIO(data)).convert("RGB")
                    log.info("Successfully extracted image from Gemini response")
                    return image
            # No image in this part: surface any text, which may explain why.
            text = getattr(part, "text", None)
            if text:
                log.warning("Gemini returned text instead of image in part %d: %s", part_idx, text[:200])
    return None


def _log_missing_image(response, candidates) -> None:
    """Log best-effort diagnostics for a response that carried no image."""
    try:
        response_text = str(response)
        log.error("Gemini generate_content returned no image. Full response (first 1000 chars): %s", response_text[:1000])
        if hasattr(response, "prompt_feedback"):
            log.error("Prompt feedback: %s", response.prompt_feedback)
        for idx, candidate in enumerate(candidates):
            log.error("Candidate %d finish_reason: %s", idx, getattr(candidate, "finish_reason", None))
    except Exception:
        # Diagnostics must never mask the real failure raised by the caller.
        log.debug("Could not log Gemini response diagnostics", exc_info=True)


def _call_gemini_edit(
    image_rgb: np.ndarray,
    mask_bw: np.ndarray,
    prompt: str | None,
    target_size: tuple[int, int],
) -> Image.Image:
    """
    Ask Gemini to remove the masked region and return the edited image.

    Two PNG images are sent after the prompt, in this order: a guidance copy
    of ``image_rgb`` with every masked pixel painted white, and the mask
    itself. The first inline image in the response is returned, resized to
    ``target_size`` when Gemini produced a different size.

    Args:
        image_rgb: Source image array; indexed with ``mask_bw > 0``, so its
            first two dimensions must match ``mask_bw``.
        mask_bw: Single-channel mask; any pixel greater than 0 is removed.
        prompt: Instruction text; ``None`` or empty falls back to
            ``DEFAULT_PROMPT``.
        target_size: Desired output size as ``(width, height)``.

    Returns:
        The edited image in RGB mode.

    Raises:
        RuntimeError: If the API call fails, no candidates are returned, a
            candidate was blocked, or the response contains no image.
    """
    model = _get_gemini_model()
    mask_image = Image.fromarray(mask_bw).convert("L")
    # Build a guidance image where the removal region is painted white for clarity
    guidance_rgb = image_rgb.copy()
    guidance_rgb[mask_bw > 0] = 255
    guidance_image = Image.fromarray(guidance_rgb).convert("RGB")
    mask_bytes = _pil_to_png_bytes(mask_image)
    guidance_bytes = _pil_to_png_bytes(guidance_image)

    # Enrich prompt to explicitly describe the two images being sent
    effective_prompt = (
        (prompt or DEFAULT_PROMPT).strip()
        + "\nIMAGE ORDER:\n"
        + "Image A: Original photo with the removal region painted white.\n"
        + "Image B: Binary mask (white=remove, black=keep). Use this mask to decide what to remove.\n"
    )
    log.info(
        "Calling Gemini generate_content model=%s (mask-guided remove) mask_pixels=%d",
        model.model_name,
        int((mask_bw > 0).sum()),
    )
    # Build content parts: prompt + guidance image + mask image (explicit order)
    content = [
        effective_prompt,
        {"mime_type": "image/png", "data": guidance_bytes},
        {"mime_type": "image/png", "data": mask_bytes},
    ]
    # Note: response_mime_type doesn't support image/png in the old google.generativeai package
    # Images are returned in response parts as inline_data
    try:
        response = model.generate_content(content, stream=False)
    except Exception as gen_err:
        log.error("Gemini generate_content raised exception: %s", gen_err, exc_info=True)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Gemini API error: {gen_err}") from gen_err

    candidates = getattr(response, "candidates", [])
    if not candidates:
        log.error("Gemini returned no candidates")
        raise RuntimeError("Gemini API returned no candidates. The request may have been blocked.")

    _raise_if_blocked(candidates)

    output_img: Image.Image | None = None
    try:
        output_img = _extract_first_image(candidates)
    except Exception as err:
        # An undecodable payload is treated the same as a missing image.
        log.error("Failed to parse Gemini response image: %s", err, exc_info=True)

    if output_img is None:
        _log_missing_image(response, candidates)
        raise RuntimeError("Gemini generate_content returned no image. Check logs for details.")

    # Ensure output matches original dimensions if Gemini rescaled
    if output_img.size != target_size:
        output_img = output_img.resize(target_size, Image.Resampling.LANCZOS)
    return output_img
def process_inpaint(
    image: np.ndarray,
    mask: np.ndarray,
    invert_mask: bool = True,
    prompt: str | None = None,
) -> np.ndarray:
    """
    Remove the masked region of ``image`` by delegating to the Gemini edit call.

    The image is normalized to RGB, the mask is normalized to RGBA and reduced
    to a 0/255 binary mask (``invert_mask`` is forwarded to that reduction),
    and the mask is resized to the image's height and width. The edited
    result is returned as an array.
    """
    # Go through RGBA before dropping to RGB, mirroring the mask handling.
    source_rgb = np.array(Image.fromarray(image).convert("RGBA").convert("RGB"))
    height, width = source_rgb.shape[:2]

    raw_mask = np.array(Image.fromarray(mask).convert("RGBA"))
    binary_mask = _resize_mask(
        _binary_mask_from_rgba(raw_mask, invert_mask),
        (height, width),
    )

    # The edit call expects its target size as (width, height).
    edited = _call_gemini_edit(source_rgb, binary_mask, prompt, (width, height))
    return np.array(edited)