"""Gradio app that builds Stable Diffusion prompts from character images.

An image either comes from a bundled character thumbnail database (searched by
name) or is uploaded by the user. It is tagged with a WD14 ONNX tagger and the
resulting tags are filtered by a ban list and by tag-classification categories.
"""
import json
import csv
import io
import base64
import gzip
import os
import numpy as np
import pandas as pd
from PIL import Image
from io import BytesIO
import onnxruntime as rt
import huggingface_hub
import gradio as gr
from huggingface_hub import hf_hub_download

# Fetch the data files the app reads from the working directory.
for filename in ["classified_tags_danbooru_beta.csv"]:  # classified_tags_danbooru_alpha.csv
    hf_hub_download(repo_id="r3gm/classified_tags", repo_type="dataset", filename=filename, local_dir=".")

for filename in ["wai_character_md5.csv", "wai_character_thumbs.json"]:
    hf_hub_download(repo_id="flagrantia/character_select_stand_alone_app", repo_type="dataset", filename=filename, local_dir=".")

# --- Global variables ---
_character_md5_map = {}
_character_thumbs_data = {}
_danbooru_tag_classifier_df = None  # Global for the classifier dataframe
_preprocessed_allowed_tags_set = set()  # Stores tags allowed by current category filter
_last_selected_tag_categories = None  # To track if categories changed

# --- WD14 Tagger Globals ---
_wd14_predictor_instance = None

# Available models
WD14_MODELS = {
    "WD-SwinV2-V3 (lite)": "SmilingWolf/wd-swinv2-tagger-v3",
    "WD-ViT-L-V3": "SmilingWolf/wd-vit-large-tagger-v3",
}
_wd14_selected_model_repo = WD14_MODELS["WD-SwinV2-V3 (lite)"]  # Default selected model

# Default thresholds and MCut enabled status (can be changed by Gradio inputs)
_wd14_general_threshold = 0.35
_wd14_character_threshold = 0.85
_wd14_general_mcut_enabled = False
_wd14_character_mcut_enabled = False

# --- Tagger model specific file names ---
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the tag and must not become spaces.
# NOTE(review): the bare "_" entry looks like it may be a mangled "<o>_<o>" -
# confirm against the tagger's reference kaomoji list.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]

# --- Configuration ---
DEFAULT_MAX_DISPLAY_RESULTS = 6
CLASSIFIED_TAGS_CSV = "classified_tags_danbooru_beta.csv"


# --- Utility Functions ---
def _resolve_data_path(path):
    """Return *path*, or the same name one directory above this script.

    The fallback is used only when *path* does not exist and the fallback
    does; otherwise *path* is returned unchanged so the caller's own
    open/read reports the missing file.
    """
    if not os.path.exists(path):
        fallback = os.path.join(os.path.dirname(__file__), '..', path)
        if os.path.exists(fallback):
            return fallback
    return path


def load_character_data_for_app(md5_csv_path='wai_character_md5.csv',
                                thumbs_json_path='wai_character_thumbs.json'):
    """Populate the module-level character and tag-classifier tables.

    Each table is loaded only while it is still empty/unloaded, so the
    function is cheap to call repeatedly (it runs on every search).

    Args:
        md5_csv_path: CSV with rows of ``name, md5``.
        thumbs_json_path: JSON object that is looked up by md5 hash.

    Failures are printed and leave the affected table empty; nothing is raised.
    """
    global _character_md5_map, _character_thumbs_data, _danbooru_tag_classifier_df

    if not _character_md5_map:  # Load only if not already loaded
        try:
            md5_csv_path = _resolve_data_path(md5_csv_path)
            with open(md5_csv_path, 'r', encoding='utf-8') as csvfile:
                for row in csv.reader(csvfile):
                    if row and len(row) >= 2:
                        original_name = row[0].strip()
                        md5_hash = row[1].strip()
                        # Keyed by lowercase name for case-insensitive search.
                        _character_md5_map[original_name.lower()] = {'original_name': original_name, 'md5': md5_hash}
        except FileNotFoundError:
            print(f"Error: {md5_csv_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_md5_map = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading MD5 data: {e}")
            _character_md5_map = {}

    if not _character_thumbs_data:  # Load only if not already loaded
        try:
            thumbs_json_path = _resolve_data_path(thumbs_json_path)
            with open(thumbs_json_path, 'r', encoding='utf-8') as jsonfile:
                _character_thumbs_data = json.load(jsonfile)
            print(f"Loaded {len(_character_thumbs_data)} thumbnail entries from {thumbs_json_path}.")
        except FileNotFoundError:
            print(f"Error: {thumbs_json_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_thumbs_data = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading thumbnail data: {e}")
            _character_thumbs_data = {}

    # Load tag classifier data once. Re-reading the CSV on every call made each
    # search pay for a full pd.read_csv; an empty frame (failed load) is retried.
    if _danbooru_tag_classifier_df is None or _danbooru_tag_classifier_df.empty:
        try:
            classifier_csv_path = _resolve_data_path(CLASSIFIED_TAGS_CSV)
            classifier_df = pd.read_csv(classifier_csv_path, index_col='name')
            # Drop 'tag_id' if it exists and is not needed for filtering
            if 'tag_id' in classifier_df.columns:
                classifier_df = classifier_df.drop(columns=['tag_id'])
            _danbooru_tag_classifier_df = classifier_df
        except Exception as e:
            print(f"CRITICAL ERROR: Failed to load {CLASSIFIED_TAGS_CSV}. Tag filtering by classification will be unavailable and likely cause errors.")
            print(f"Error loading tag classifier data: {e}")
            _danbooru_tag_classifier_df = pd.DataFrame()  # Ensure it's an empty DataFrame to avoid further crashes


def base64_to_pil_image(base64_data):
    """Decode a base64 string holding gzip-compressed image bytes into an RGBA PIL image.

    Returns None (after printing the error) if any decoding step fails.
    """
    try:
        compressed_data = base64.b64decode(base64_data)
        webp_data = gzip.decompress(compressed_data)
        image = Image.open(BytesIO(webp_data)).convert("RGBA")
        return image
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        return None


def escape_parentheses_for_prompt(text: str) -> str:
    """Escapes parentheses for use in stable diffusion prompts."""
    return text.replace('(', r'\(').replace(')', r'\)')


def search_character_data_partial_for_app(character_name_partial, max_results):
    """Return up to *max_results* characters whose name contains the query.

    Matching is a case-insensitive substring test. Characters without a
    thumbnail entry, or whose thumbnail fails to decode, are skipped.

    Returns:
        A list of dicts with keys 'name', 'image' (PIL image) and 'md5'.
        Empty when the query is empty or None.
    """
    search_query_lower = (character_name_partial or "").lower()
    if not search_query_lower:
        return []

    results = []
    for lower_name, char_info in _character_md5_map.items():
        if search_query_lower not in lower_name:
            continue
        md5_hash = char_info['md5']
        if md5_hash not in _character_thumbs_data:
            continue
        pil_image = base64_to_pil_image(_character_thumbs_data[md5_hash])
        if pil_image is None:
            continue
        results.append({
            'name': char_info['original_name'],
            'image': pil_image,
            'md5': md5_hash  # Keep md5 for potential future use or debugging
        })
        if len(results) >= max_results:
            break
    return results


# --- WD14 Tagger Specific Functions ---
def load_labels(dataframe) -> tuple[list[str], list, list, list]:
    """Split a tagger label table into tag names and per-category row indexes.

    Underscores in names become spaces, except for entries in ``kaomojis``.

    Returns:
        (tag_names, rating_indexes, general_indexes, character_indexes), where
        the index lists are the row positions whose "category" is 9, 0 and 4.
    """
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def mcut_threshold(probs):
    """Return the midpoint of the largest gap between consecutive sorted scores.

    *probs* is a 1-D numpy array with at least two elements; callers guard
    the length before calling.
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh


class Predictor:
    """Holds one loaded WD14 ONNX tagger and swaps it when another repo is requested."""

    def __init__(self, model_repo_default: str):
        # Initialize with a default model repo
        self.model_target_size = None
        self.last_loaded_repo = None
        self.model = None
        self.tag_names = None
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None
        self.load_model(model_repo_default)  # Load the default model at init

    def download_model_files(self, model_repo):
        """Download the label CSV and ONNX model from *model_repo*; return their local paths."""
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load the tagger from *model_repo* unless it is already the active one.

        Labels are read into locals and committed together with the session,
        so a failed download or session creation never leaves the instance
        with labels from one model and a session from another. On failure the
        exception propagates and the instance is left with no model loaded.
        """
        if model_repo == self.last_loaded_repo and self.model is not None:
            return

        csv_path, model_path = self.download_model_files(model_repo)
        tags_df = pd.read_csv(csv_path)
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)

        # Release the previous session before creating the new one, but keep
        # the attribute present and the state consistent if creation fails.
        self.model = None
        self.last_loaded_repo = None

        session = rt.InferenceSession(model_path)
        _, height, _width, _ = session.get_inputs()[0].shape

        self.model = session
        self.model_target_size = height
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        print(f"WD14 model loaded successfully from {model_repo}")
        self.last_loaded_repo = model_repo

    def prepare_image(self, image: Image.Image):  # Explicitly type-hint as PIL.Image.Image
        """
        Prepares a PIL Image for the WD14 tagger model.
        Resizes and converts the image to the model's expected input format.

        Returns a float32 array with a leading batch dimension of 1.
        Raises ValueError if *image* is not a PIL image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input to prepare_image must be a PIL.Image.Image object.")

        target_size = self.model_target_size

        # Convert to RGBA first to handle potential alpha channel correctly for white background
        image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")  # Now convert to RGB after blending with white background

        # Pad to a square on a white background, centring the image.
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.Resampling.BICUBIC,  # Use Image.Resampling
            )

        image_array = np.asarray(padded_image, dtype=np.float32)
        # The model expects BGR, so reverse the channels
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)  # Add batch dimension

    def predict(
        self,
        image: Image.Image,
        model_repo: str,
        general_thresh: float,
        general_mcut_enabled: bool,
        character_thresh: float,
        character_mcut_enabled: bool,
    ):
        """Tag *image* and return (general_tag_names, character_tag_names).

        Both lists are sorted by descending score and contain only tags whose
        score is strictly above the applicable threshold. When MCut is enabled
        for a group (and it has more than one tag) the threshold is replaced
        by ``mcut_threshold``; for characters it is additionally floored at 0.15.
        """
        # Load model if it's different from the currently loaded one
        self.load_model(model_repo)  # This will skip if already loaded

        image_input = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image_input})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            if len(general_probs) > 1:
                general_thresh = mcut_threshold(general_probs)
        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res_dict = dict(general_res)

        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            if len(character_probs) > 1:
                character_thresh = mcut_threshold(character_probs)
                character_thresh = max(0.15, character_thresh)
        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res_dict = dict(character_res)

        # Prepare tag names for prompt (without confidence, no escaping yet)
        general_tag_names = [
            x[0] if x[0] in kaomojis else x[0].replace('_', ' ')
            for x in sorted(general_res_dict.items(), key=lambda x: x[1], reverse=True)
        ]
        character_tag_names = [
            x[0].replace('_', ' ')
            for x in sorted(character_res_dict.items(), key=lambda x: x[1], reverse=True)
        ]

        return general_tag_names, character_tag_names


# --- Gradio Interface Functions ---

# Store search results temporarily for dropdown/gallery sync
_last_search_results = []


def search_characters_gradio(character_name_partial, max_results):
    """Run a character search and build the updates for the search tab.

    Returns a 6-tuple: gallery, dropdown, selected name, selected image,
    selected md5 and status message. The first hit is selected automatically;
    with no hits every selection output is cleared.
    """
    global _last_search_results
    load_character_data_for_app()  # Ensure data is loaded
    found_chars = search_character_data_partial_for_app(character_name_partial, max_results)
    _last_search_results = found_chars  # Store results

    if not found_chars:
        # Clear selected character states when no results are found
        return (
            gr.Gallery(value=[], selected_index=None),  # Clear gallery
            gr.Dropdown(choices=[], interactive=False, label="Select Character", value=None),  # Clear dropdown value
            "",  # Clear selected_char_original_name
            None,  # Clear selected_pil_image_from_search
            "",  # Clear selected_char_md5_hash
            "No characters found."  # Message
        )

    gallery_images_with_names = [(char['image'], char['name']) for char in found_chars]
    # Dropdown values are indexes into _last_search_results.
    dropdown_options = [(char['name'], i) for i, char in enumerate(found_chars)]

    # Automatically select the first character if results are found
    first_char = found_chars[0]
    return (
        gr.Gallery(value=gallery_images_with_names, selected_index=0),
        gr.Dropdown(choices=dropdown_options, interactive=True, label="Select Character", value=0),
        first_char['name'],
        first_char['image'],
        first_char['md5'],
        ""  # Clear status message
    )


def get_selected_character_info_by_index(selected_index):
    """Return (name, image, md5, gallery update) for an index into the last search results.

    An out-of-range or None index clears all four outputs.
    """
    if _last_search_results and selected_index is not None and 0 <= selected_index < len(_last_search_results):
        selected_char = _last_search_results[selected_index]
        return (
            selected_char['name'],
            selected_char['image'],
            selected_char['md5'],
            # Ensure gallery selection is visually updated
            gr.Gallery(value=[(char['image'], char['name']) for char in _last_search_results], selected_index=selected_index)
        )
    # If no valid selection, clear everything
    return "", None, "", gr.Gallery(value=[], selected_index=None)


def update_dropdown_from_gallery(evt: gr.SelectData):
    """Forward the clicked gallery index to the dropdown (None deselects it)."""
    # evt.index gives the index of the clicked image
    if evt.index is not None:
        return evt.index
    return None  # Return None to deselect dropdown if an empty area is clicked (though gallery itself might prevent this)


def update_allowed_tags_set(selected_categories: list):
    """
    Pre-processes the allowed tags set based on selected categories.
    This function is called when the tag category filter changes.

    The result is stored in the module-level ``_preprocessed_allowed_tags_set``
    as lowercase tag names with underscores replaced by spaces (kaomojis keep
    their underscores). The set is emptied when no categories are selected or
    the classifier table is unavailable.
    """
    global _preprocessed_allowed_tags_set, _last_selected_tag_categories

    classifier_unavailable = _danbooru_tag_classifier_df is None or _danbooru_tag_classifier_df.empty
    if classifier_unavailable or not selected_categories:
        _preprocessed_allowed_tags_set = set()
        _last_selected_tag_categories = selected_categories
        return

    # Check if categories have actually changed
    if selected_categories == _last_selected_tag_categories:
        return

    allowed_tags_set = set()
    for category in selected_categories:
        if category in _danbooru_tag_classifier_df.columns:
            # Get tags where the category column has a '1'
            # Convert to lower case for case-insensitive matching
            tags_in_category = _danbooru_tag_classifier_df[_danbooru_tag_classifier_df[category] == 1].index.tolist()
            allowed_tags_set.update(
                tag.lower() if tag in kaomojis else tag.replace('_', ' ').lower()
                for tag in tags_in_category
            )

    _preprocessed_allowed_tags_set = allowed_tags_set
    # Keep a copy so later changes to the caller's list cannot defeat the
    # "categories unchanged" check above.
    _last_selected_tag_categories = list(selected_categories)


# Unified prompt generation logic, now taking image and char_name explicitly
def _generate_prompt_logic(
    image_to_tag: Image.Image,
    char_name_from_source: str,  # This is the char name explicitly chosen from DB search, if applicable
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,  # Now covers char name AND detected character tags
):
    """Tag an image and assemble the final prompt.

    Returns:
        (status_message, prompt). Order in the prompt is: preserved character
        name and character tags first, then the category-filtered remainder;
        duplicates are removed case-insensitively and parentheses are escaped.
        Any failure while loading the model or tagging is reported through
        the status message rather than raised.
    """
    global _wd14_predictor_instance

    if image_to_tag is None:
        return "Error: No image provided for tagging.", "No prompt generated."

    try:
        # Resolved inside the try so an unknown model name or a failed model
        # download is reported to the UI instead of escaping the handler.
        model_repo = WD14_MODELS[wd14_model_choice]
        if _wd14_predictor_instance is None:
            # Initialize with the currently selected model from the dropdown
            _wd14_predictor_instance = Predictor(model_repo)

        general_tags, character_tags = _wd14_predictor_instance.predict(
            image=image_to_tag,
            model_repo=model_repo,  # Pass the selected model repo
            general_thresh=general_thresh,
            general_mcut_enabled=general_mcut_enabled,
            character_thresh=character_thresh,
            character_mcut_enabled=character_mcut_enabled,
        )
    except Exception as e:
        return f"Error during tag generation: {e}", "Error generating prompt."

    # Process banned words (a set: only membership is ever tested)
    banned_words = {word.strip().lower() for word in (banned_words_str or "").split(',') if word.strip()}

    prompt_parts = []
    tags_to_filter = []  # Tags that will be subject to category filtering

    # 1. Handle character name from database search (if applicable)
    char_name_lower = char_name_from_source.lower() if char_name_from_source else ""
    if char_name_from_source and char_name_lower not in banned_words:
        if preserve_char_name_and_tags:
            prompt_parts.append(char_name_from_source)
        else:
            # If not preserving, it will be treated like any other tag for filtering
            tags_to_filter.append(char_name_from_source)

    # 2. Handle WD14 detected character tags
    for tag in character_tags:
        tag_lower = tag.lower()
        if tag_lower in banned_words:
            continue
        if preserve_char_name_and_tags:
            # Ensure it's not a duplicate of char_name_from_source if that was already added
            if not (char_name_from_source and tag_lower == char_name_lower):
                prompt_parts.append(tag)
        else:
            tags_to_filter.append(tag)

    # 3. Handle WD14 detected general tags
    tags_to_filter.extend(tag for tag in general_tags if tag.lower() not in banned_words)

    # --- TAG CLASSIFICATION FILTERING for tags_to_filter ---
    # Ensure allowed tags set is up-to-date with current category selection
    update_allowed_tags_set(selected_tag_categories)

    if _preprocessed_allowed_tags_set:  # If there are active category filters
        filtered_categorized_tags = [
            tag for tag in tags_to_filter if tag.lower() in _preprocessed_allowed_tags_set
        ]
    else:
        # If no categories are selected or classifier data is not active, use all raw tags
        filtered_categorized_tags = tags_to_filter
    # --- END TAG CLASSIFICATION FILTERING ---

    prompt_parts.extend(filtered_categorized_tags)

    # Ensure uniqueness and order (optional, but good practice for prompts)
    final_prompt_list = []
    seen_lower = set()  # Use a set of lowercased tags for uniqueness check
    for item in prompt_parts:
        item_lower = item.lower()
        if item_lower not in seen_lower:
            final_prompt_list.append(item)
            seen_lower.add(item_lower)

    final_prompt = ", ".join(final_prompt_list)
    final_prompt_escaped = escape_parentheses_for_prompt(final_prompt)

    return "Prompt generated successfully!", final_prompt_escaped


# Wrapper function for the "Search Character Database" tab
def generate_prompt_from_search_tab(
    character_original_name: str,
    selected_pil_image_from_search: Image.Image,
    wd14_model_choice: str,  # Added model choice
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,  # Now covers char name AND detected character tags
):
    """Generate a prompt for the character selected in the search tab.

    The selected thumbnail is the image that gets tagged and the selected
    character's name is passed on as the explicit character name.
    """
    return _generate_prompt_logic(
        image_to_tag=selected_pil_image_from_search,
        char_name_from_source=character_original_name,
        wd14_model_choice=wd14_model_choice,  # Pass model choice
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )


# Wrapper function for the "Upload Your Own Image" tab
def generate_prompt_from_upload_tab(
    input_image_upload: Image.Image,
    wd14_model_choice: str,  # Added model choice
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,  # Still relevant for detected char tags
):
    """Generate a prompt for an uploaded image.

    There is no explicit character name from a database search for the
    upload tab, so an empty name is passed on.
    """
    return _generate_prompt_logic(
        image_to_tag=input_image_upload,
        char_name_from_source="",
        wd14_model_choice=wd14_model_choice,  # Pass model choice
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )


# Initialize the predictor once globally with the default model
_wd14_predictor_instance = Predictor(_wd14_selected_model_repo)

# --- Gradio Interface Layout ---

# Load data once at the start of the script
load_character_data_for_app()

# Get tag classification categories for the dropdown, if available
tag_classification_categories = []
if _danbooru_tag_classifier_df is not None and not _danbooru_tag_classifier_df.empty:
    tag_classification_categories = [col for col in _danbooru_tag_classifier_df.columns if col not in ['tag_id']]
else:
    print("tag classifier data not loaded or empty. Tag category filtering will be disabled.")

CSS = """
#gallery {
    height: 300px;
    max-height: 520px;
    margin-left: 0; /* no left margin */
    flex-grow: 1;
}
"""

with gr.Blocks(title="Character Prompt Generator", css=CSS) as demo:
    gr.Markdown(
        """
        # Character Prompt Generator
        Generate prompts for character images either by searching the database or uploading your own!
        """
    )

    # --- Shared Tagger Settings (Moved to the top for clarity) ---
    gr.Markdown("#### ⚙️ Tagger Settings (Apply to both methods)")
    with gr.Accordion("Adjust Tagging Parameters", open=False):
        # NEW: Model Selection Dropdown
        wd14_model_dropdown = gr.Dropdown(
            label="WD14 Tagger Model",
            choices=list(WD14_MODELS.keys()),
            value=list(WD14_MODELS.keys())[0],
            interactive=True,
            info="Select which WD14 tagger model to use. Different models may yield slightly different tags."
        )
        with gr.Row():
            general_threshold_slider = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=_wd14_general_threshold,
                step=0.01,
                label="General Tags Threshold",
                interactive=True,
                info="Confidence score required for general tags to be included."
            )
            general_mcut_checkbox = gr.Checkbox(
                value=_wd14_general_mcut_enabled,
                label="Use MCut for General Tags",
                interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        with gr.Row():
            character_threshold_slider = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=_wd14_character_threshold,
                step=0.01,
                label="Character Tags Threshold",
                interactive=True,
                info="Confidence score required for character tags to be included."
            )
            character_mcut_checkbox = gr.Checkbox(
                value=_wd14_character_mcut_enabled,
                label="Use MCut for Character Tags",
                interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        banned_words_input = gr.Textbox(
            label="Exclude Tags (comma-separated, case-insensitive)",
            value="simple background, lowres, text",
            placeholder="e.g., text, watermark",
            interactive=True,
            info="Tags listed here will be removed from the final prompt."
        )
        tag_category_filter = gr.Dropdown(
            label="Filter General Tags by Category",
            choices=tag_classification_categories,
            multiselect=True,
            interactive=bool(tag_classification_categories),
            value=["body", "hair", "eyes_face", "species_traits", "head_accessories", "subjects_relationship"] if tag_classification_categories else [],
            info="Only general tags belonging to selected categories will be included. Character tags are not affected by this filter unless 'Preserve Characters' is unchecked."
        )
        preserve_char_name_and_tags_checkbox = gr.Checkbox(
            label="Preserve Character-Specific Tags (Bypass Category Filter)",
            value=True,
            interactive=True,
            info="If checked, the selected character's name (from search) and any detected character tags will *always* be included (unless banned), bypassing the 'Filter Tags by Category' setting. General tags are still filtered."
        )

    # Hidden states to pass character data to prompt generation (keep these global-ish)
    selected_pil_image_from_search = gr.State(None)
    selected_char_original_name = gr.State("")
    selected_char_md5_hash = gr.State("")

    with gr.Tabs():
        with gr.TabItem("1. Search Character Database"):
            gr.Markdown("### 🔍 Find a Character")
            with gr.Row():
                with gr.Column(scale=1):
                    search_input = gr.Textbox(
                        label="Enter Character Name",
                        placeholder="e.g., Hatsune, zhongli",
                        interactive=True
                    )
                    max_results_slider = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=DEFAULT_MAX_DISPLAY_RESULTS,
                        step=1,
                        label="Max Search Results to Display",
                        interactive=True
                    )
                    search_output_message = gr.Markdown("Enter a character name to search.")
                with gr.Column(scale=2):
                    character_gallery = gr.Gallery(
                        label="Matching Characters (Click to select)",
                        show_label=True,
                        elem_id="gallery",
                        columns=3,
                        rows=15,
                        object_fit="contain",
                        # height="auto",
                        allow_preview=False,
                    )

            gr.Markdown("---")
            gr.Markdown("### ➡️ Selected Character")
            selected_character_dropdown = gr.Dropdown(
                label="Selected Character",
                choices=[],
                interactive=False,
                info="The character chosen from the search results above."
            )
            generate_prompt_button_search = gr.Button("Generate Prompt for Selected Character", variant="primary")

        with gr.TabItem("2. Upload Your Own Image"):
            gr.Markdown("### ⬆️ Upload an Image to Tag")
            input_image_upload = gr.Image(
                label="Upload Image (JPG, PNG, WEBP)",
                type="pil",
                height=200,
                interactive=True,
                # info="This image will be sent to the tagger for prompt generation."
            )
            generate_prompt_button_upload = gr.Button("Generate Prompt for Uploaded Image", variant="primary")

    gr.Markdown("---")
    gr.Markdown("#### ✨ Generated Prompt")
    # Outputs are now consolidated here, after the inputs and buttons
    generate_prompt_status = gr.Markdown(
        "",
        elem_id="shared_prompt_status"
    )
    prompt_output = gr.Textbox(
        label="",
        placeholder="Your generated prompt will appear here...",
        lines=5,
        interactive=True,
        show_copy_button=True,
        elem_id="shared_prompt_output"
    )

    # --- Event Handlers ---

    # Search Character Database Tab Interactions
    search_input.change(
        fn=search_characters_gradio,
        inputs=[search_input, max_results_slider],
        outputs=[character_gallery, selected_character_dropdown, selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, search_output_message]
    )
    max_results_slider.change(
        fn=search_characters_gradio,
        inputs=[search_input, max_results_slider],
        outputs=[character_gallery, selected_character_dropdown, selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, search_output_message]
    )
    selected_character_dropdown.change(
        fn=get_selected_character_info_by_index,
        inputs=[selected_character_dropdown],
        outputs=[selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, character_gallery]
    )
    character_gallery.select(
        fn=update_dropdown_from_gallery,
        inputs=None,
        outputs=[selected_character_dropdown]
    )

    # Tag Category Filter (Shared Setting)
    tag_category_filter.change(
        fn=update_allowed_tags_set,
        inputs=[tag_category_filter],
        outputs=None,  # This function only updates a global, doesn't return a Gradio component
        api_name=False  # Don't expose this helper function via API
    )
    # Also call it once at startup with the initial value to populate the global set
    demo.load(
        fn=lambda: update_allowed_tags_set(tag_category_filter.value),
        inputs=[],
        outputs=[]
    )

    # Generate Prompt Button for "Search Character Database" tab
    generate_prompt_button_search.click(
        fn=generate_prompt_from_search_tab,
        inputs=[
            selected_char_original_name,
            selected_pil_image_from_search,
            wd14_model_dropdown,
            general_threshold_slider,
            general_mcut_checkbox,
            character_threshold_slider,
            character_mcut_checkbox,
            banned_words_input,
            tag_category_filter,
            preserve_char_name_and_tags_checkbox
        ],
        outputs=[generate_prompt_status, prompt_output]
    )

    # Generate Prompt Button for "Upload Your Own Image" tab
    generate_prompt_button_upload.click(
        fn=generate_prompt_from_upload_tab,
        inputs=[
            input_image_upload,
            wd14_model_dropdown,
            general_threshold_slider,
            general_mcut_checkbox,
            character_threshold_slider,
            character_mcut_checkbox,
            banned_words_input,
            tag_category_filter,
            preserve_char_name_and_tags_checkbox
        ],
        outputs=[generate_prompt_status, prompt_output]
    )

demo.launch(debug=True, show_error=True)