"""
Processor class for ILLUME_plus with dualvitok and dualvitok-sdxl-decoder.
"""
import json
import re  # used for parsing image placeholder tokens
from typing import List, Tuple, Union

try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

import torch
from PIL import Image

from transformers import AutoProcessor, AutoImageProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import (
    ProcessingKwargs,
    ProcessorMixin,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .configuration_illume import ILLUMEConfig
from .image_utils import ImageInput
from .image_processing_movqgan import MoVQImageProcessor
from .image_processing_dualvitok import DualViTokImageProcessor
from .aspect_ratio_utils import AspectRatioCrop, RATIOS, unpad_and_resize_back
from .inference_utils import parse_interleaved_text_image, calculate_image_token_num
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline

logger = logging.get_logger(__name__)
class ILLUMEProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
}
class ILLUMEProcessor(ProcessorMixin):
r"""
    Constructs an ILLUME processor which wraps a DualViTok image processor and a Qwen2 tokenizer into a single
    processor.
    [`ILLUMEProcessor`] offers all the functionalities of [`DualViTokImageProcessor`] and [`Qwen2TokenizerFast`]. See
    [`~ILLUMEProcessor.__call__`] and [`~ILLUMEProcessor.decode`] for more information.
    Args:
        image_processor ([`DualViTokImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
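
    Example (a minimal sketch for a text-to-image prompt; the checkpoint path below is a placeholder,
    not a verified repo id):

    ```python
    >>> from transformers import AutoProcessor

    >>> processor = AutoProcessor.from_pretrained("path/to/illume_plus-qwen2_5-7b-hf", trust_remote_code=True)
    >>> tag = processor.get_resolution_tag_from_resolution((512, 512))
    >>> prompt = processor.default_generation_template.format(resolution_tag=tag, content="a red bicycle")
    >>> inputs = processor(text=[prompt], return_tensors="pt")
    ```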
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
_re_placeholder = re.compile(r"<image_out>|<image>")
_default_generation_template = "Generate an image of {resolution_tag}, the content of image is {content}\n"
_default_generation_unconditional_template = "Generate a random image of {resolution_tag}\n"
_default_editing_template = "{resolution_tag}<image>\nPlease edit the image according to the instruction: {content}\n"
_default_editing_unconditional_template = "{resolution_tag}<image>\nReconstruct the image according to the given image\n"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None,
crop_percent_thresh=0.2, **kwargs):
super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
self.vision_tokenizer = None
self.diffusion_vision_detokenizer = None
self.crop_percent_thresh = crop_percent_thresh
@property
def default_generation_template(self):
return self._default_generation_template
@property
def default_generation_unconditional_template(self):
return self._default_generation_unconditional_template
@property
def default_editing_template(self):
return self._default_editing_template
@property
def default_editing_unconditional_template(self):
return self._default_editing_unconditional_template
def get_resolution_tag_from_resolution(self, resolution):
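        # e.g. resolution == (512, 512) -> "<height_512><width_512>"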
return f"<height_{resolution[0]}><width_{resolution[1]}>"
def set_vision_tokenizer(self, tokenizer):
        if self.vision_tokenizer and tokenizer:
            logger.info('Vision tokenizer is already set; keeping the existing one.')
            return
self.vision_tokenizer = tokenizer
logger.info('Setting vision tokenizer!')
def load_diffusion_vision_detokenizer(self, diffusion_decoder,
torch_dtype=torch.float16,
add_watermarker=False,
device='cuda',
):
        if self.diffusion_vision_detokenizer:
            logger.info('Diffusion vision detokenizer is already loaded; keeping the existing one.')
            return
if self.vision_tokenizer is None:
raise ValueError("Vision tokenizer is not set. Please set the vision tokenizer by using `processor.set_vision_tokenizer`")
self.diffusion_vision_detokenizer = StableDiffusionXLDecoderPipeline.from_pretrained(diffusion_decoder,
torch_dtype=torch_dtype,
add_watermarker=add_watermarker,
vq_model=self.vision_tokenizer,
vq_image_processor=self.image_processor).to(device)
logger.info('Setting diffusion vision detokenizer!')
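
    # Typical wiring before decoding generated image tokens (a sketch; `vq_model` and the decoder
    # path below are placeholders, not names defined in this file):
    #     processor.set_vision_tokenizer(vq_model)  # the DualViTok vision tokenizer module
    #     processor.load_diffusion_vision_detokenizer("path/to/dualvitok-sdxl-decoder")
    #     images = processor.decode_images(image_inds_list, use_diffusion=True)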
def get_ratio_tag_from_ratio(self, ratio):
h, w = ratio
ratio_tag = f"<height_{h}><width_{w}>"
return ratio_tag
@torch.no_grad()
def _encode_with_dualvitok(self, img):
        # `img` is a PIL image or np.ndarray.
        px = self.image_processor(img, return_tensors="pt")["pixel_values"].to(self.vision_tokenizer.device)
        # The vision tokenizer returns one tuple per branch; keep only the semantic and pixel index grids.
        (_, _, idx_sem, _), (_, _, idx_pix) = self.vision_tokenizer.encode(px)
        # Drop the batch dimension and return plain Python lists: (semantic_indices, pixel_indices).
        return idx_sem[0].cpu().tolist(), idx_pix[0].cpu().tolist()
def transform_image_nearest_resolution_ratio(self, image, ratios=RATIOS):
arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
image, original_size, target_size, flag_matched = arc(image, is_inference=True)
return image
def convert_image_to_token_string(self, image, ratios=RATIOS):
arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
image, original_size, target_size, flag_matched = arc(image, is_inference=True)
ratio_tag = self.get_ratio_tag_from_ratio(target_size)
image_embed_inds = self._encode_with_dualvitok(image)
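        # The result is the "<height_H><width_W>" resolution tag followed by the serialized vision
        # tokens (see `encode_image_token_into_code` for the token layout).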
return ratio_tag + self.encode_image_token_into_code(image_embed_inds)
def unpad_and_resize_back(self, padded_image, original_width, original_height):
return unpad_and_resize_back(padded_image, original_width, original_height)
def encode_image_token_into_code(self, image_embed_inds,
add_token_name="<|image_level{}_{}|>",
selected_vision_tokenizer_levels=None):
'''
Args:
            image_embed_inds: nested list of vision token ids, one 2D grid per tokenizer level
            add_token_name: format string used to name each vision token, filled with (level, token id)
            selected_vision_tokenizer_levels: optional list of level indices; if given, only these levels are kept
Returns:
image_token_return: str
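
        Example layout (for a 2x2 grid at each of two levels; `a`..`d` stand for integer token ids and
        the real string contains no line breaks):
            <start_of_image><start_of_level0>
            <|image_level0_a|><|image_level0_b|><end_of_line>
            <|image_level0_c|><|image_level0_d|><end_of_line>
            <end_of_level0><start_of_level1>...<end_of_level1><end_of_image>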
'''
if selected_vision_tokenizer_levels is not None:
image_embed_inds_new = []
for level in selected_vision_tokenizer_levels:
image_embed_inds_new.append(image_embed_inds[level])
image_embed_inds = image_embed_inds_new
image_token_name_list = []
for level, image_embed_ind in enumerate(image_embed_inds):
image_token_name = []
for row in image_embed_ind:
image_token_name.append([add_token_name.format(level, ind) for ind in row])
image_token_name_list.append("<start_of_level{}>".format(level))
for row in image_token_name:
row.append("<end_of_line>")
for row in image_token_name:
image_token_name_list.extend(row)
image_token_name_list.append("<end_of_level{}>".format(level))
image_token_return = "".join(image_token_name_list)
image_token_return = "<start_of_image>" + image_token_return + "<end_of_image>"
return image_token_return
@torch.no_grad()
def decode_images(self, image_inds_list, target_resolution=(512, 512), return_type='pil',
use_diffusion=False, diffusion_cfg_scale=2.0, diffusion_num_inference_steps=20, **kwargs):
token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*target_resolution)
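        # (h1, w1) and (h2, w2) are the semantic- and pixel-level token grid sizes implied by the target
        # resolution; the flat index lists below are reshaped onto these grids before decoding.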
decoded_images = []
for image_inds in image_inds_list:
semantic_code = torch.as_tensor([image_inds[0]])
texture_code = torch.as_tensor([image_inds[1]])
if use_diffusion:
                if self.diffusion_vision_detokenizer is None:
                    raise RuntimeError(
                        "diffusion_vision_detokenizer is not set. Please load the diffusion decoder with "
                        "`processor.load_diffusion_vision_detokenizer`.")
semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
texture_code = texture_code.view(texture_code.shape[0], h2, w2)
diffusion_outputs = self.diffusion_vision_detokenizer(
vq_indices=(semantic_code, texture_code),
height=target_resolution[0] * 2,
width=target_resolution[1] * 2,
guidance_scale=diffusion_cfg_scale,
num_inference_steps=diffusion_num_inference_steps,
output_type=return_type,
**kwargs
)
samples = diffusion_outputs.images
image = samples[0]
else:
                if self.vision_tokenizer is None:
                    raise RuntimeError(
                        "vision_tokenizer is not set. Please set the vision tokenizer with "
                        "`processor.set_vision_tokenizer`.")
semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
texture_code = texture_code.view(texture_code.shape[0], h2, w2)
samples = self.vision_tokenizer.decode_code(semantic_code, texture_code)
                if return_type == 'pil':
                    sample = torch.clamp(127.5 * samples + 128.0, 0, 255)
                    sample = sample.permute(0, 2, 3, 1).cpu().to(torch.uint8).numpy()[0]
                    image = Image.fromarray(sample)
                else:  # return a numpy array in the range [-1, 1]
                    image = samples.permute(0, 2, 3, 1).cpu().numpy()[0]
decoded_images.append(image)
return decoded_images
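
    # End-to-end decoding sketch (hedged; `generated_text` is assumed to be the raw text produced by
    # the language model, and `target_resolution` to match the resolution tag used in the prompt):
    #     _, image_embed_inds_list, _ = processor.parse_text_image(generated_text, '<image_out>')
    #     images = processor.decode_images(image_embed_inds_list, target_resolution=(512, 512))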
def parse_text_image(self, text, image_placeholder='<image_out>'):
generated_text, image_embed_inds_list, list_image_token_parts = parse_interleaved_text_image(text, num_levels=2,
image_placeholder=image_placeholder)
return generated_text, image_embed_inds_list, list_image_token_parts
def _encode_out_placeholder(self, img):
"""
Encode one image with DualViTok and return a string
that can replace the <image_out> marker in the text.
"""
return self.convert_image_to_token_string(img)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
**kwargs: Unpack[ILLUMEProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, it forwards the `images` and `kwargs` arguments to
        DualViTokImageProcessor's [`~DualViTokImageProcessor.__call__`] if `images` is not `None`. Each `<image>`
        placeholder in `text` keeps the matching image as pixel features, while each `<image_out>` placeholder is
        replaced in the text by that image's DualViTok token string (this requires the vision tokenizer to be set).
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
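
        Example (a sketch; the image path is a placeholder):

        ```python
        >>> from PIL import Image

        >>> image = Image.open("example.jpg")
        >>> inputs = processor(text=["<image>\nDescribe this image."], images=[image], return_tensors="pt")
        ```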
"""
output_kwargs = self._merge_kwargs(
ILLUMEProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if not isinstance(text, list):
text = [text]
if isinstance(images, str):
images = [images]
elif images and isinstance(images[0], list):
# flatten List[List[PIL.Image.Image]]
images = [item for sublist in images for item in sublist]
_ = output_kwargs["text_kwargs"].pop("padding_side", None)
        try:
            text = self.apply_chat_template(text, add_generation_prompt=True, padding=True)
        except Exception:
            logger.info('Could not apply the chat template; assuming the input texts already have it applied.')
if images is None:
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs})
else:
imgs_in, new_text, used = [], [], 0
if not isinstance(text, list):
text = [text]
for s in text: # walk each prompt
out, i = [], 0
for m in self._re_placeholder.finditer(s): # every placeholder
out.append(s[i:m.start()])
if used >= len(images):
raise ValueError("not enough images for placeholders")
img = images[used]
used += 1
if m.group() == "<image_out>":
out.append(self.convert_image_to_token_string(img)) # replace
else: # <image>
out.append("<image>")
imgs_in.append(img) # keep for pixel feats
i = m.end()
out.append(s[i:])
new_text.append("".join(out))
if used != len(images):
raise ValueError(f"too many images for placeholders. used {used} vs len(images) {len(images)}. {text}")
text_inputs = self.tokenizer(new_text, **output_kwargs["text_kwargs"])
image_inputs = self.image_processor.preprocess(imgs_in, **output_kwargs["images_kwargs"]) if imgs_in else {}
return BatchFeature(data={**text_inputs, **image_inputs})
    def batch_decode(self, sequences, *args, **kwargs):
        return [self.decode(seq, *args, **kwargs) for seq in sequences]
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))