"""
Processor class for ILLUME_plus with dualvitok and dualvitok-sdxl-decoder.
"""

import json
import re
from typing import List, Tuple, Union

import torch
from PIL import Image

from transformers import AutoProcessor, AutoImageProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import (
    ProcessingKwargs,
    ProcessorMixin,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

from .aspect_ratio_utils import AspectRatioCrop, RATIOS, unpad_and_resize_back
from .configuration_illume import ILLUMEConfig
from .image_processing_dualvitok import DualViTokImageProcessor
from .image_processing_movqgan import MoVQImageProcessor
from .image_utils import ImageInput
from .inference_utils import parse_interleaved_text_image, calculate_image_token_num
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline

logger = logging.get_logger(__name__)


class ILLUMEProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class ILLUMEProcessor(ProcessorMixin):
    r"""
    Constructs an ILLUME processor which wraps a DualViTok image processor and a Qwen2 tokenizer into a single
    processor.

    [`ILLUMEProcessor`] offers all the functionalities of [`DualViTokImageProcessor`] and [`Qwen2TokenizerFast`]. See
    the [`~ILLUMEProcessor.__call__`] and [`~ILLUMEProcessor.decode`] for more information.

    Args:
        image_processor ([`DualViTokImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*):
            A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
        crop_percent_thresh (`float`, *optional*, defaults to 0.2):
            Threshold passed to [`AspectRatioCrop`] controlling how much of an image may be cropped away when
            snapping it to the nearest supported aspect ratio.
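
    Example (illustrative sketch; the checkpoint path and `pil_image` below are placeholders):

    ```python
    >>> from transformers import AutoProcessor

    >>> processor = AutoProcessor.from_pretrained("path/to/illume_plus-checkpoint", trust_remote_code=True)

    >>> # Understanding-style input: each <image> placeholder consumes one image from `images`.
    >>> inputs = processor(text="<image>\nDescribe the image.", images=[pil_image])

    >>> # Text-to-image prompt built from the default generation template.
    >>> prompt = processor.default_generation_template.format(
    ...     resolution_tag=processor.get_resolution_tag_from_resolution((512, 512)),
    ...     content="a photo of a corgi on the beach",
    ... )
    ```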
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    _re_placeholder = re.compile(r"<image_out>|<image>")

    _default_generation_template = "Generate an image of {resolution_tag}, the content of image is {content}\n"
    _default_generation_unconditional_template = "Generate a random image of {resolution_tag}\n"

    _default_editing_template = "{resolution_tag}<image>\nPlease edit the image according to the instruction: {content}\n"
    _default_editing_unconditional_template = "{resolution_tag}<image>\nReconstruct the image according to the given image\n"

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None,
                 crop_percent_thresh=0.2, **kwargs):
        super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
        self.vision_tokenizer = None
        self.diffusion_vision_detokenizer = None
        self.crop_percent_thresh = crop_percent_thresh

    @property
    def default_generation_template(self):
        return self._default_generation_template

    @property
    def default_generation_unconditional_template(self):
        return self._default_generation_unconditional_template

    @property
    def default_editing_template(self):
        return self._default_editing_template

    @property
    def default_editing_unconditional_template(self):
        return self._default_editing_unconditional_template

    def get_resolution_tag_from_resolution(self, resolution):
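        """Return a resolution tag, e.g. ``(512, 512)`` -> ``"<height_512><width_512>"``."""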
        return f"<height_{resolution[0]}><width_{resolution[1]}>"

    def set_vision_tokenizer(self, tokenizer):
        if self.vision_tokenizer and tokenizer:
            logger.warning('Vision tokenizer is already set; keeping the existing one.')
            return
        self.vision_tokenizer = tokenizer
        logger.info('Setting vision tokenizer!')

    def load_diffusion_vision_detokenizer(self, diffusion_decoder,
                                          torch_dtype=torch.float16,
                                          add_watermarker=False,
                                          device='cuda',
                                          ):
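        """
        Attach the dualvitok-sdxl-decoder pipeline used by `decode_images(..., use_diffusion=True)`.

        Minimal sketch (the checkpoint paths and the `AutoModel` loading call are placeholders, not guaranteed
        model ids or APIs; any DualViTok model accepted by `set_vision_tokenizer` works here):

        ```python
        >>> from transformers import AutoModel
        >>> vq_model = AutoModel.from_pretrained("path/to/dualvitok", trust_remote_code=True).eval().to("cuda")
        >>> processor.set_vision_tokenizer(vq_model)
        >>> processor.load_diffusion_vision_detokenizer("path/to/dualvitok-sdxl-decoder")
        ```
        """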
        if self.diffusion_vision_detokenizer:
            logger.warning('Diffusion vision detokenizer is already set; keeping the existing one.')
            return

        if self.vision_tokenizer is None:
            raise ValueError(
                "Vision tokenizer is not set. Please set the vision tokenizer by using `processor.set_vision_tokenizer`.")

        self.diffusion_vision_detokenizer = StableDiffusionXLDecoderPipeline.from_pretrained(
            diffusion_decoder,
            torch_dtype=torch_dtype,
            add_watermarker=add_watermarker,
            vq_model=self.vision_tokenizer,
            vq_image_processor=self.image_processor,
        ).to(device)
        logger.info('Setting diffusion vision detokenizer!')

    def get_ratio_tag_from_ratio(self, ratio):
        h, w = ratio
        ratio_tag = f"<height_{h}><width_{w}>"
        return ratio_tag

    @torch.no_grad()
    def _encode_with_dualvitok(self, img):
        px = self.image_processor(img, return_tensors="pt")["pixel_values"].to(self.vision_tokenizer.device)
        (_, _, idx_sem, _), (_, _, idx_pix) = self.vision_tokenizer.encode(px)
        return idx_sem[0].cpu().tolist(), idx_pix[0].cpu().tolist()

    def transform_image_nearest_resolution_ratio(self, image, ratios=RATIOS):
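        """Snap `image` to the nearest supported aspect ratio in `ratios` (via `AspectRatioCrop`) and return it."""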
        arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
        image, original_size, target_size, flag_matched = arc(image, is_inference=True)
        return image

    def convert_image_to_token_string(self, image, ratios=RATIOS):
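        """
        Encode `image` into its discrete DualViTok token string, prefixed with the matched resolution tag.

        The result is what an `<image_out>` placeholder is replaced with in `__call__`, e.g. (shape and ids are
        illustrative) ``"<height_512><width_512><start_of_image><start_of_level0>...<end_of_image>"``.
        """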
        arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
        image, original_size, target_size, flag_matched = arc(image, is_inference=True)
        ratio_tag = self.get_ratio_tag_from_ratio(target_size)

        image_embed_inds = self._encode_with_dualvitok(image)
        return ratio_tag + self.encode_image_token_into_code(image_embed_inds)

    def unpad_and_resize_back(self, padded_image, original_width, original_height):
        return unpad_and_resize_back(padded_image, original_width, original_height)

    def encode_image_token_into_code(self, image_embed_inds,
                                     add_token_name="<|image_level{}_{}|>",
                                     selected_vision_tokenizer_levels=None):
        '''
        Args:
            image_embed_inds: 3D list, vision token ids for each tokenizer level
            add_token_name: format string used to build the tag name for each vision token
            selected_vision_tokenizer_levels: optional list of level indices to encode; all levels are used if None
        Returns:
            image_token_return: str
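
        Illustrative example (the ids are made up; real ids come from the vision tokenizer): for
        ``image_embed_inds = [[[1, 2]], [[3, 4]]]`` the returned string is
        ``"<start_of_image><start_of_level0><|image_level0_1|><|image_level0_2|><end_of_line><end_of_level0>"
        "<start_of_level1><|image_level1_3|><|image_level1_4|><end_of_line><end_of_level1><end_of_image>"``.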
        '''

        if selected_vision_tokenizer_levels is not None:
            image_embed_inds_new = []
            for level in selected_vision_tokenizer_levels:
                image_embed_inds_new.append(image_embed_inds[level])
            image_embed_inds = image_embed_inds_new

        image_token_name_list = []
        for level, image_embed_ind in enumerate(image_embed_inds):
            image_token_name = []
            for row in image_embed_ind:
                image_token_name.append([add_token_name.format(level, ind) for ind in row])

            image_token_name_list.append("<start_of_level{}>".format(level))
            for row in image_token_name:
                row.append("<end_of_line>")

            for row in image_token_name:
                image_token_name_list.extend(row)

            image_token_name_list.append("<end_of_level{}>".format(level))

        image_token_return = "".join(image_token_name_list)
        image_token_return = "<start_of_image>" + image_token_return + "<end_of_image>"
        return image_token_return

    @torch.no_grad()
    def decode_images(self, image_inds_list, target_resolution=(512, 512), return_type='pil',
                      use_diffusion=False, diffusion_cfg_scale=2.0, diffusion_num_inference_steps=20, **kwargs):
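        """
        Decode DualViTok token ids back into images, either with the DualViTok pixel decoder or, when
        `use_diffusion=True`, with the SDXL-based detokenizer (which renders at twice `target_resolution`).

        Each entry of `image_inds_list` is expected to hold `[semantic_ids, texture_ids]` for one image, e.g. as
        returned by `parse_text_image`.

        Example (illustrative sketch; `generated_text` stands for the detokenized model output):

        ```python
        >>> _, image_embed_inds_list, _ = processor.parse_text_image(generated_text, image_placeholder='<image_out>')
        >>> images = processor.decode_images(image_embed_inds_list, target_resolution=(512, 512), use_diffusion=True)
        ```
        """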
        token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*target_resolution)

        decoded_images = []
        for image_inds in image_inds_list:
            semantic_code = torch.as_tensor([image_inds[0]])
            texture_code = torch.as_tensor([image_inds[1]])
            if use_diffusion:
                if self.diffusion_vision_detokenizer is None:
                    raise RuntimeError(
                        "diffusion_vision_detokenizer is not set. Please set the diffusion decoder by using "
                        "`processor.load_diffusion_vision_detokenizer`.")

                semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
                texture_code = texture_code.view(texture_code.shape[0], h2, w2)

                diffusion_outputs = self.diffusion_vision_detokenizer(
                    vq_indices=(semantic_code, texture_code),
                    height=target_resolution[0] * 2,
                    width=target_resolution[1] * 2,
                    guidance_scale=diffusion_cfg_scale,
                    num_inference_steps=diffusion_num_inference_steps,
                    output_type=return_type,
                    **kwargs
                )
                samples = diffusion_outputs.images
                image = samples[0]
            else:
                if self.vision_tokenizer is None:
                    raise RuntimeError(
                        "vision_tokenizer is not set. Please set the vision tokenizer by using "
                        "`processor.set_vision_tokenizer`.")

                semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
                texture_code = texture_code.view(texture_code.shape[0], h2, w2)

                samples = self.vision_tokenizer.decode_code(semantic_code, texture_code)

                if return_type == 'pil':
                    sample = torch.clamp(127.5 * samples + 128.0, 0, 255)
                    sample = sample.permute(0, 2, 3, 1).cpu().to(torch.uint8).numpy()[0]
                    image = Image.fromarray(sample)
                else:
                    image = samples.permute(0, 2, 3, 1).cpu().numpy()[0]
            decoded_images.append(image)

        return decoded_images

    def parse_text_image(self, text, image_placeholder='<image_out>'):
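        """
        Split interleaved model output into plain text and per-image token ids.

        Returns a tuple ``(generated_text, image_embed_inds_list, list_image_token_parts)``, where each entry of
        ``image_embed_inds_list`` holds the two DualViTok levels (semantic and pixel ids) for one generated image.
        """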
        generated_text, image_embed_inds_list, list_image_token_parts = parse_interleaved_text_image(
            text, num_levels=2, image_placeholder=image_placeholder)
        return generated_text, image_embed_inds_list, list_image_token_parts

    def _encode_out_placeholder(self, img):
        """
        Encode one image with DualViTok and return a string
        that can replace the <image_out> marker in the text.
        """
        return self.convert_image_to_token_string(img)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        **kwargs: Unpack[ILLUMEProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
        DualViTokImageProcessor's [`~DualViTokImageProcessor.__call__`] if `images` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            ILLUMEProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if not isinstance(text, list):
            text = [text]

        if isinstance(images, str):
            images = [images]
        elif images and isinstance(images[0], list):
            # Flatten a nested (per-sample) list of images into a single flat list.
            images = [item for sublist in images for item in sublist]

        _ = output_kwargs["text_kwargs"].pop("padding_side", None)
        try:
            text = self.apply_chat_template(text, add_generation_prompt=True, padding=True)
        except Exception:
            logger.info('Could not apply the chat template; assuming the input texts already have it applied.')

        if images is None:
            text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
            return BatchFeature(data={**text_inputs})
        else:
            imgs_in, new_text, used = [], [], 0

            if not isinstance(text, list):
                text = [text]

            for s in text:
                out, i = [], 0
                for m in self._re_placeholder.finditer(s):
                    out.append(s[i:m.start()])
                    if used >= len(images):
                        raise ValueError("not enough images for placeholders")
                    img = images[used]
                    used += 1
                    if m.group() == "<image_out>":
                        # <image_out> placeholders are replaced by discrete DualViTok token strings.
                        out.append(self.convert_image_to_token_string(img))
                    else:
                        # <image> placeholders are kept; the image goes through the image processor instead.
                        out.append("<image>")
                        imgs_in.append(img)
                    i = m.end()
                out.append(s[i:])
                new_text.append("".join(out))

            if used != len(images):
                raise ValueError(
                    f"too many images for placeholders: used {used} vs len(images) {len(images)}. {text}")

            text_inputs = self.tokenizer(new_text, **output_kwargs["text_kwargs"])
            image_inputs = self.image_processor.preprocess(imgs_in, **output_kwargs["images_kwargs"]) if imgs_in else {}

            return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, sequences, *args, **kwargs):
        return [self.decode(seq, *args, **kwargs)
                for seq in sequences]

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))