""" Processor class for ILLUME_plus with dualvitok and dualvitok-sdxl-decoder. """ import json from typing import List, Union from transformers import AutoProcessor, AutoImageProcessor try: from typing import Unpack except ImportError: from typing_extensions import Unpack from transformers.feature_extraction_utils import BatchFeature from .image_utils import ImageInput from transformers.processing_utils import ( ProcessingKwargs, ProcessorMixin, ) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging from PIL import Image import re # Added for parsing image tokens from typing import List, Tuple import torch from .configuration_illume import ILLUMEConfig from .image_processing_movqgan import MoVQImageProcessor from .image_processing_dualvitok import DualViTokImageProcessor from .aspect_ratio_utils import AspectRatioCrop, RATIOS, unpad_and_resize_back from .inference_utils import parse_interleaved_text_image, calculate_image_token_num from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline logger = logging.get_logger(__name__) class ILLUMEProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, }, } class ILLUMEProcessor(ProcessorMixin): r""" Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor. [`ILLUMEProcessor`] offers all the functionalities of [`ILLUMEImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~ILLUMEProcessor.__call__`] and [`~ILLUMEProcessor.decode`] for more information. Args: image_processor ([`IllumeImageProcessor`], *optional*): The image processor is a required input. tokenizer ([`Qwen2TokenizerFast`], *optional*): The tokenizer is a required input. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. 
""" attributes = ["image_processor", "tokenizer"] valid_kwargs = ["chat_template"] image_processor_class = "AutoImageProcessor" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") _re_placeholder = re.compile(r"|") _default_generation_template = "Generate an image of {resolution_tag}, the content of image is {content}\n" _default_generation_unconditional_template = "Generate a random image of {resolution_tag}\n" _default_editing_template = "{resolution_tag}\nPlease edit the image according to the instruction: {content}\n" _default_editing_unconditional_template = "{resolution_tag}\nReconstruct the image according to the given image\n" def __init__(self, image_processor=None, tokenizer=None, chat_template=None, crop_percent_thresh=0.2, **kwargs): super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template) self.vision_tokenizer = None self.diffusion_vision_detokenizer = None self.crop_percent_thresh = crop_percent_thresh @property def default_generation_template(self): return self._default_generation_template @property def default_generation_unconditional_template(self): return self._default_generation_unconditional_template @property def default_editing_template(self): return self._default_editing_template @property def default_editing_unconditional_template(self): return self._default_editing_unconditional_template def get_resolution_tag_from_resolution(self, resolution): return f"" def set_vision_tokenizer(self, tokenizer): if self.vision_tokenizer and tokenizer: logger.info('You are resetting vision tokenizer!') return self.vision_tokenizer = tokenizer logger.info('Setting vision tokenizer!') def load_diffusion_vision_detokenizer(self, diffusion_decoder, torch_dtype=torch.float16, add_watermarker=False, device='cuda', ): if self.diffusion_vision_detokenizer: logger.info('You are resetting diffusion vision detokenizer!') return if self.vision_tokenizer is None: raise ValueError("Vision tokenizer is not set. 
    def load_diffusion_vision_detokenizer(self, diffusion_decoder,
                                          torch_dtype=torch.float16,
                                          add_watermarker=False,
                                          device='cuda'):
        if self.diffusion_vision_detokenizer:
            logger.info('Diffusion vision detokenizer is already set; ignoring the new one.')
            return
        if self.vision_tokenizer is None:
            raise ValueError("Vision tokenizer is not set. "
                             "Please set the vision tokenizer by using `processor.set_vision_tokenizer`.")
        self.diffusion_vision_detokenizer = StableDiffusionXLDecoderPipeline.from_pretrained(
            diffusion_decoder,
            torch_dtype=torch_dtype,
            add_watermarker=add_watermarker,
            vq_model=self.vision_tokenizer,
            vq_image_processor=self.image_processor,
        ).to(device)
        logger.info('Setting diffusion vision detokenizer!')

    def get_ratio_tag_from_ratio(self, ratio):
        # Aspect-ratio tag placed before the image codes, e.g. "<height_512><width_512>" (tag format assumed).
        h, w = ratio
        ratio_tag = f"<height_{h}><width_{w}>"
        return ratio_tag

    @torch.no_grad()
    def _encode_with_dualvitok(self, img):
        # img is a PIL image or np.ndarray
        px = self.image_processor(img, return_tensors="pt")["pixel_values"].to(self.vision_tokenizer.device)
        (_, _, idx_sem, _), (_, _, idx_pix) = self.vision_tokenizer.encode(px)
        return idx_sem[0].cpu().tolist(), idx_pix[0].cpu().tolist()

    def transform_image_nearest_resolution_ratio(self, image, ratios=RATIOS):
        arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
        image, original_size, target_size, flag_matched = arc(image, is_inference=True)
        return image

    def convert_image_to_token_string(self, image, ratios=RATIOS):
        arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
        image, original_size, target_size, flag_matched = arc(image, is_inference=True)
        ratio_tag = self.get_ratio_tag_from_ratio(target_size)
        image_embed_inds = self._encode_with_dualvitok(image)
        return ratio_tag + self.encode_image_token_into_code(image_embed_inds)

    def unpad_and_resize_back(self, padded_image, original_width, original_height):
        return unpad_and_resize_back(padded_image, original_width, original_height)

    def encode_image_token_into_code(self, image_embed_inds, add_token_name="<|image_level{}_{}|>",
                                     selected_vision_tokenizer_levels=None):
        '''
        Serialize vision-tokenizer indices into a string of special tokens.

        Args:
            image_embed_inds: 3D list, vision token ids for each tokenizer level
            add_token_name: tag name for vision tokens
        Returns:
            image_token_return: str
        '''
        if selected_vision_tokenizer_levels is not None:
            image_embed_inds_new = []
            for level in selected_vision_tokenizer_levels:
                image_embed_inds_new.append(image_embed_inds[level])
            image_embed_inds = image_embed_inds_new

        # Delimiter tokens below (<start_of_image>, <start_of_level{}>, <end_of_line>, ...) are assumed spellings.
        image_token_name_list = []
        for level, image_embed_ind in enumerate(image_embed_inds):
            image_token_name = []
            for row in image_embed_ind:
                image_token_name.append([add_token_name.format(level, ind) for ind in row])

            image_token_name_list.append("<start_of_level{}>".format(level))
            for row in image_token_name:
                row.append("<end_of_line>")
            for row in image_token_name:
                image_token_name_list.extend(row)
            image_token_name_list.append("<end_of_level{}>".format(level))

        image_token_return = "".join(image_token_name_list)
        image_token_return = "<start_of_image>" + image_token_return + "<end_of_image>"
        return image_token_return
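
    # Illustration of the serialized layout produced above for a two-level DualViTok grid
    # (token spellings as assumed in `encode_image_token_into_code`); a 2x2 level-0 grid
    # would serialize roughly as:
    #
    #     <start_of_image><start_of_level0>
    #     <|image_level0_11|><|image_level0_42|><end_of_line>
    #     <|image_level0_7|><|image_level0_3|><end_of_line>
    #     <end_of_level0><start_of_level1>...<end_of_level1><end_of_image>
    #
    # (line breaks added here for readability only; the actual string contains none).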
    @torch.no_grad()
    def decode_images(self, image_inds_list, target_resolution=(512, 512), return_type='pil',
                      use_diffusion=False, diffusion_cfg_scale=2.0, diffusion_num_inference_steps=20,
                      **kwargs):
        token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*target_resolution)

        decoded_images = []
        for image_inds in image_inds_list:
            semantic_code = torch.as_tensor([image_inds[0]])
            texture_code = torch.as_tensor([image_inds[1]])

            if use_diffusion:
                if self.diffusion_vision_detokenizer is None:
                    raise RuntimeError(
                        "diffusion_vision_detokenizer is not set. Please set the diffusion decoder "
                        "by using `processor.load_diffusion_vision_detokenizer`.")
                semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
                texture_code = texture_code.view(texture_code.shape[0], h2, w2)

                diffusion_outputs = self.diffusion_vision_detokenizer(
                    vq_indices=(semantic_code, texture_code),
                    height=target_resolution[0] * 2,
                    width=target_resolution[1] * 2,
                    guidance_scale=diffusion_cfg_scale,
                    num_inference_steps=diffusion_num_inference_steps,
                    output_type=return_type,
                    **kwargs
                )
                samples = diffusion_outputs.images
                image = samples[0]
            else:
                if self.vision_tokenizer is None:
                    raise RuntimeError(
                        "vision_tokenizer is not set. Please set the vision tokenizer "
                        "by using `processor.set_vision_tokenizer`.")
                semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
                texture_code = texture_code.view(texture_code.shape[0], h2, w2)
                samples = self.vision_tokenizer.decode_code(semantic_code, texture_code)

                if return_type == 'pil':
                    sample = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).cpu().to(
                        torch.uint8).numpy()[0]
                    image = Image.fromarray(sample)
                else:
                    # return a numpy array in range -1 to 1.
                    image = samples.permute(0, 2, 3, 1).cpu().numpy()[0]
            decoded_images.append(image)
        return decoded_images

    def parse_text_image(self, text, image_placeholder='<image_out>'):
        # `image_placeholder` replaces each embedded block of image codes in the returned text
        # (placeholder spelling assumed).
        generated_text, image_embed_inds_list, list_image_token_parts = parse_interleaved_text_image(
            text, num_levels=2, image_placeholder=image_placeholder)
        return generated_text, image_embed_inds_list, list_image_token_parts

    def _encode_out_placeholder(self, img):
        """
        Encode one image with DualViTok and return a string that can replace
        the image-generation placeholder in the text.
        """
        return self.convert_image_to_token_string(img)
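
    # Sketch of turning generated text back into images (variable names are illustrative;
    # `generated_text` is the decoded model output containing serialized image codes):
    #
    #     text, image_inds_list, _ = processor.parse_text_image(generated_text)
    #     images = processor.decode_images(image_inds_list, target_resolution=(512, 512),
    #                                      use_diffusion=True, diffusion_cfg_scale=2.0)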
    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        **kwargs: Unpack[ILLUMEProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
        DualViTokImageProcessor's [`~DualViTokImageProcessor.__call__`] if `images` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is
              not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            ILLUMEProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if not isinstance(text, list):
            text = [text]

        if isinstance(images, (str, Image.Image)):
            # a single image (or image path) becomes a one-element batch
            images = [images]
        elif images and isinstance(images[0], list):
            # flatten List[List[PIL.Image.Image]]
            images = [item for sublist in images for item in sublist]

        _ = output_kwargs["text_kwargs"].pop("padding_side", None)

        try:
            text = self.apply_chat_template(text, add_generation_prompt=True, padding=True)
        except Exception:
            logger.info('Could not apply the chat template; assuming the input texts are already formatted.')

        if images is None:
            text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
            return BatchFeature(data={**text_inputs})
        else:
            imgs_in, new_text, used = [], [], 0
            for s in text:  # walk each prompt
                out, i = [], 0
                for m in self._re_placeholder.finditer(s):  # every placeholder
                    out.append(s[i:m.start()])
                    if used >= len(images):
                        raise ValueError("not enough images for placeholders")
                    img = images[used]
                    used += 1
                    if m.group() == "<image_out>":
                        # inline the image as discrete vision-tokenizer codes
                        out.append(self.convert_image_to_token_string(img))
                    else:
                        # out.append("<image>")
                        imgs_in.append(img)  # keep for pixel feats
                    i = m.end()
                out.append(s[i:])
                new_text.append("".join(out))

            if used != len(images):
                raise ValueError(f"too many images for placeholders. used {used} vs len(images) {len(images)}. {text}")

            text_inputs = self.tokenizer(new_text, **output_kwargs["text_kwargs"])
            image_inputs = self.image_processor.preprocess(imgs_in, **output_kwargs["images_kwargs"]) if imgs_in else {}
            return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, sequences, *args, **kwargs):
        return [self.decode(seq, *args, **kwargs) for seq in sequences]

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
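

# Minimal end-to-end sketch, guarded so importing this module stays side-effect free.
# The checkpoint path below is a placeholder; the processor is assumed to be exposed
# through AutoProcessor with trust_remote_code enabled.
if __name__ == "__main__":
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("path/to/illume_plus_checkpoint", trust_remote_code=True)

    # Build a text-to-image prompt from the default generation template.
    prompt = processor.default_generation_template.format(
        resolution_tag=processor.get_resolution_tag_from_resolution((512, 512)),
        content="a snowy mountain at sunrise",
    )
    inputs = processor(text=[prompt], return_tensors="pt")
    print(inputs["input_ids"].shape)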