Student0809 committed on
Commit 7feac49 · verified · 1 Parent(s): cb2428f

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +10 -0
  2. docs/transformers/build/lib/transformers/models/chameleon/modeling_chameleon.py +1673 -0
  3. docs/transformers/build/lib/transformers/models/chameleon/processing_chameleon.py +177 -0
  4. docs/transformers/build/lib/transformers/models/chinese_clip/configuration_chinese_clip.py +434 -0
  5. docs/transformers/build/lib/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py +134 -0
  6. docs/transformers/build/lib/transformers/models/chinese_clip/feature_extraction_chinese_clip.py +38 -0
  7. docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip.py +314 -0
  8. docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +40 -0
  9. docs/transformers/build/lib/transformers/models/chinese_clip/modeling_chinese_clip.py +1630 -0
  10. docs/transformers/build/lib/transformers/models/chinese_clip/processing_chinese_clip.py +163 -0
  11. docs/transformers/build/lib/transformers/models/clap/__init__.py +29 -0
  12. docs/transformers/build/lib/transformers/models/clap/configuration_clap.py +394 -0
  13. docs/transformers/build/lib/transformers/models/clap/convert_clap_original_pytorch_to_hf.py +133 -0
  14. docs/transformers/build/lib/transformers/models/clap/feature_extraction_clap.py +367 -0
  15. docs/transformers/build/lib/transformers/models/clap/modeling_clap.py +0 -0
  16. docs/transformers/build/lib/transformers/models/clap/processing_clap.py +120 -0
  17. docs/transformers/build/lib/transformers/models/clip/__init__.py +35 -0
  18. docs/transformers/build/lib/transformers/models/clip/convert_clip_original_pytorch_to_hf.py +156 -0
  19. old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl +3 -0
  20. old/dataset_10k_train.jsonl +3 -0
  21. seamless_interaction/assets/banner.gif +3 -0
  22. swift/llm/template/__pycache__/vision_utils.cpython-310.pyc +0 -0
  23. swift/llm/template/template/__init__.py +2 -0
  24. swift/llm/template/template/__pycache__/__init__.cpython-310.pyc +0 -0
  25. swift/llm/template/template/__pycache__/deepseek.cpython-310.pyc +0 -0
  26. swift/llm/template/template/__pycache__/emu3.cpython-310.pyc +0 -0
  27. swift/llm/template/template/__pycache__/gemma.cpython-310.pyc +0 -0
  28. swift/llm/template/template/__pycache__/glm.cpython-310.pyc +0 -0
  29. swift/llm/template/template/__pycache__/idefics3.cpython-310.pyc +0 -0
  30. swift/llm/template/template/__pycache__/internlm.cpython-310.pyc +0 -0
  31. swift/llm/template/template/__pycache__/internvl.cpython-310.pyc +0 -0
  32. swift/llm/template/template/__pycache__/llama.cpython-310.pyc +0 -0
  33. swift/llm/template/template/__pycache__/llava.cpython-310.pyc +0 -0
  34. swift/llm/template/template/__pycache__/llm.cpython-310.pyc +0 -0
  35. swift/llm/template/template/__pycache__/megrez.cpython-310.pyc +0 -0
  36. swift/llm/template/template/__pycache__/microsoft.cpython-310.pyc +0 -0
  37. swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc +0 -0
  38. swift/llm/template/template/__pycache__/minimax.cpython-310.pyc +0 -0
  39. swift/llm/template/template/__pycache__/mistral.cpython-310.pyc +0 -0
  40. swift/llm/template/template/__pycache__/molmo.cpython-310.pyc +0 -0
  41. swift/llm/template/template/__pycache__/moonshot.cpython-310.pyc +0 -0
  42. swift/llm/template/template/__pycache__/mplug.cpython-310.pyc +0 -0
  43. swift/llm/template/template/__pycache__/openbuddy.cpython-310.pyc +0 -0
  44. swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc +0 -0
  45. swift/llm/template/template/__pycache__/qwen.cpython-310.pyc +0 -0
  46. swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc +0 -0
  47. swift/llm/template/template/__pycache__/utils.cpython-310.pyc +0 -0
  48. swift/llm/template/template/__pycache__/valley.cpython-310.pyc +0 -0
  49. swift/llm/template/template/__pycache__/yi.cpython-310.pyc +0 -0
  50. swift/llm/template/template/deepseek.py +315 -0
.gitattributes CHANGED
@@ -38,3 +38,13 @@ docs/resources/web-ui.jpg filter=lfs diff=lfs merge=lfs -text
38
  docs/resources/dpo_data.png filter=lfs diff=lfs merge=lfs -text
39
  docs/transformers/tests/fixtures/tests_samples/COCO/000000039769.png filter=lfs diff=lfs merge=lfs -text
40
  docs/transformers/tests/fixtures/tests_samples/COCO/000000004016.png filter=lfs diff=lfs merge=lfs -text
41
+ old/dataset_10k_train.jsonl filter=lfs diff=lfs merge=lfs -text
42
+ old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl filter=lfs diff=lfs merge=lfs -text
43
+ wandb/offline-run-20250720_214625-3kgefhnp/run-3kgefhnp.wandb filter=lfs diff=lfs merge=lfs -text
44
+ wandb/offline-run-20250722_000857-dio4c8kj/run-dio4c8kj.wandb filter=lfs diff=lfs merge=lfs -text
45
+ wandb/offline-run-20250720_155533-1r0qjmiz/run-1r0qjmiz.wandb filter=lfs diff=lfs merge=lfs -text
46
+ wandb/offline-run-20250720_231916-zbtazovk/run-zbtazovk.wandb filter=lfs diff=lfs merge=lfs -text
47
+ wandb/offline-run-20250624_115955-iye05c18/run-iye05c18.wandb filter=lfs diff=lfs merge=lfs -text
48
+ wandb/offline-run-20250721_000454-up3efnok/run-up3efnok.wandb filter=lfs diff=lfs merge=lfs -text
49
+ wandb/offline-run-20250722_003110-femxkckf/run-femxkckf.wandb filter=lfs diff=lfs merge=lfs -text
50
+ seamless_interaction/assets/banner.gif filter=lfs diff=lfs merge=lfs -text
docs/transformers/build/lib/transformers/models/chameleon/modeling_chameleon.py ADDED
@@ -0,0 +1,1673 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Chameleon model."""
16
+
17
+ import math
18
+ from functools import cached_property
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from ...activations import ACT2FN
28
+ from ...cache_utils import Cache, DynamicCache, StaticCache
29
+ from ...generation import GenerationMixin
30
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
31
+ from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
32
+ from ...modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ )
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
38
+ from ...utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_torch_flex_attn_available,
43
+ is_torchdynamo_compiling,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
48
+
49
+
50
+ if is_torch_flex_attn_available():
51
+ from torch.nn.attention.flex_attention import BlockMask
52
+
53
+ from ...integrations.flex_attention import make_flex_block_causal_mask
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "ChameleonConfig"
59
+ _CHECKPOINT_FOR_DOC = "meta/chameleon-7b"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 7, 4096]
61
+ _SEQ_CLASS_EXPECTED_LOSS = 1.03
62
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
66
+ class ChameleonRMSNorm(nn.Module):
67
+ def __init__(self, hidden_size, eps=1e-6):
68
+ """
69
+ ChameleonRMSNorm is equivalent to T5LayerNorm
70
+ """
71
+ super().__init__()
72
+ self.weight = nn.Parameter(torch.ones(hidden_size))
73
+ self.variance_epsilon = eps
74
+
75
+ def forward(self, hidden_states):
76
+ input_dtype = hidden_states.dtype
77
+ hidden_states = hidden_states.to(torch.float32)
78
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
+ return self.weight * hidden_states.to(input_dtype)
81
+
82
+ def extra_repr(self):
83
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
84
+
85
+
86
+ ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
87
+
88
+
89
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
90
+ # TODO(joao): add me back asap :)
91
+ class ChameleonRotaryEmbedding(nn.Module):
92
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
93
+ super().__init__()
94
+ self.scaling_factor = scaling_factor
95
+ self.dim = dim
96
+ self.max_position_embeddings = max_position_embeddings
97
+ self.base = base
98
+ inv_freq = 1.0 / (
99
+ self.base
100
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
101
+ )
102
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
103
+ # For BC we register cos and sin cached
104
+ self.max_seq_len_cached = max_position_embeddings
105
+
106
+ @torch.no_grad()
107
+ def forward(self, x, position_ids):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
110
+ position_ids_expanded = position_ids[:, None, :].float()
111
+ # Force float32 since bfloat16 loses precision on long contexts
112
+ # See https://github.com/huggingface/transformers/pull/29285
113
+ device_type = x.device.type
114
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
115
+ with torch.autocast(device_type=device_type, enabled=False):
116
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
117
+ emb = torch.cat((freqs, freqs), dim=-1)
118
+ cos = emb.cos()
119
+ sin = emb.sin()
120
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
121
+
122
+
123
+ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
124
+ """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
125
+
126
+ def forward(self, x, position_ids):
127
+ # difference to the original RoPE: a scaling factor is applied to the position ids
128
+ position_ids = position_ids.float() / self.scaling_factor
129
+ cos, sin = super().forward(x, position_ids)
130
+ return cos, sin
131
+
132
+
133
+ class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
134
+ """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
135
+
136
+ def forward(self, x, position_ids):
137
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
138
+ seq_len = torch.max(position_ids) + 1
139
+ if seq_len > self.max_position_embeddings:
140
+ base = self.base * (
141
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
142
+ ) ** (self.dim / (self.dim - 2))
143
+ inv_freq = 1.0 / (
144
+ base
145
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim)
146
+ )
147
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
148
+
149
+ cos, sin = super().forward(x, position_ids)
150
+ return cos, sin
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
+ """Applies Rotary Position Embedding to the query and key tensors.
164
+
165
+ Args:
166
+ q (`torch.Tensor`): The query tensor.
167
+ k (`torch.Tensor`): The key tensor.
168
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
170
+ position_ids (`torch.Tensor`, *optional*):
171
+ Deprecated and unused.
172
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
173
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
+ Returns:
180
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
181
+ """
182
+ cos = cos.unsqueeze(unsqueeze_dim)
183
+ sin = sin.unsqueeze(unsqueeze_dim)
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
189
+ # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
190
+ class ChameleonMLP(nn.Module):
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.config = config
194
+ self.hidden_size = config.hidden_size
195
+ self.intermediate_size = config.intermediate_size
196
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
197
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
198
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
199
+ self.act_fn = ACT2FN[config.hidden_act]
200
+
201
+ # Ignore copy
202
+ def forward(self, x):
203
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
204
+ return down_proj
205
+
206
+
207
+ class ChameleonLayerNorm(nn.LayerNorm):
208
+ """
209
+ LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
210
+ from each shard separately to each head, instead of reducing. We can apply each head's own
211
+ gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
212
+ in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
213
+ """
214
+
215
+ def __init__(self, hidden_size, *args, **kwargs):
216
+ super().__init__(hidden_size, *args, **kwargs)
217
+ self.normalized_shape = (hidden_size[-1],)
218
+
219
+ def forward(self, hidden_states):
220
+ hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
221
+ hidden_states = hidden_states * self.weight + self.bias
222
+ return hidden_states
223
+
224
+
225
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
226
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
227
+ """
228
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
229
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
230
+ """
231
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
232
+ if n_rep == 1:
233
+ return hidden_states
234
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
235
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
236
+
237
+
238
+ class ChameleonAttention(nn.Module):
239
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
240
+
241
+ def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None):
242
+ super().__init__()
243
+ self.config = config
244
+ self.layer_idx = layer_idx
245
+ if layer_idx is None:
246
+ logger.warning_once(
247
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
248
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
249
+ "when creating this class."
250
+ )
251
+
252
+ self.attention_dropout = config.attention_dropout
253
+ self.hidden_size = config.hidden_size
254
+ self.num_heads = config.num_attention_heads
255
+ self.head_dim = self.hidden_size // self.num_heads
256
+ self.num_key_value_heads = config.num_key_value_heads
257
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
258
+ self.max_position_embeddings = config.max_position_embeddings
259
+ self.rope_theta = config.rope_theta
260
+ self.is_causal = True
261
+ self.model_parallel_size = config.model_parallel_size
262
+
263
+ if (self.head_dim * self.num_heads) != self.hidden_size:
264
+ raise ValueError(
265
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
266
+ f" and `num_heads`: {self.num_heads})."
267
+ )
268
+
269
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
270
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
271
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
273
+ self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
274
+ self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
275
+ self._init_rope()
276
+
277
+ # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
278
+ # TODO(joao): add me back asap :)
279
+ def _init_rope(self):
280
+ if self.config.rope_scaling is None:
281
+ self.rotary_emb = ChameleonRotaryEmbedding(
282
+ self.head_dim,
283
+ max_position_embeddings=self.max_position_embeddings,
284
+ base=self.rope_theta,
285
+ )
286
+ else:
287
+ scaling_type = self.config.rope_scaling["type"]
288
+ scaling_factor = self.config.rope_scaling["factor"]
289
+ if scaling_type == "linear":
290
+ self.rotary_emb = ChameleonLinearScalingRotaryEmbedding(
291
+ self.head_dim,
292
+ max_position_embeddings=self.max_position_embeddings,
293
+ scaling_factor=scaling_factor,
294
+ base=self.rope_theta,
295
+ )
296
+ elif scaling_type == "dynamic":
297
+ self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ scaling_factor=scaling_factor,
301
+ base=self.rope_theta,
302
+ )
303
+ else:
304
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ attention_mask: Optional[torch.Tensor] = None,
310
+ position_ids: Optional[torch.LongTensor] = None,
311
+ past_key_value: Optional[Cache] = None,
312
+ output_attentions: bool = False,
313
+ use_cache: bool = False,
314
+ cache_position: Optional[torch.LongTensor] = None,
315
+ **kwargs,
316
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
317
+ bsz, q_len, _ = hidden_states.size()
318
+
319
+ query_states = self.q_proj(hidden_states)
320
+ key_states = self.k_proj(hidden_states)
321
+ value_states = self.v_proj(hidden_states)
322
+
323
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
324
+ query_states = self.q_norm(query_states)
325
+
326
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
327
+ key_states = self.k_norm(key_states)
328
+
329
+ query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
330
+ key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
331
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
332
+
333
+ cos, sin = self.rotary_emb(value_states, position_ids)
334
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
335
+
336
+ if past_key_value is not None:
337
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
338
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
339
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
340
+
341
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
342
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
343
+
344
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
345
+
346
+ if attention_mask is not None: # no matter the length, we just slice it
347
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
348
+ attn_weights = attn_weights + causal_mask
349
+
350
+ # upcast attention to fp32
351
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
352
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
353
+ attn_output = torch.matmul(attn_weights, value_states)
354
+
355
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
356
+ raise ValueError(
357
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
358
+ f" {attn_output.size()}"
359
+ )
360
+
361
+ attn_output = attn_output.transpose(1, 2).contiguous()
362
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
363
+ attn_output = self.o_proj(attn_output)
364
+
365
+ if not output_attentions:
366
+ attn_weights = None
367
+
368
+ return attn_output, attn_weights, past_key_value
369
+
370
+
371
+ # NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
372
+ # TODO(joao): add me back asap :)
373
+ class ChameleonFlashAttention2(ChameleonAttention):
374
+ """
375
+ Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stay
376
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
377
+ flash attention and deal with padding tokens in case the input contains any of them.
378
+ """
379
+
380
+ def __init__(self, *args, **kwargs):
381
+ super().__init__(*args, **kwargs)
382
+
383
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
384
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
385
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
386
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
387
+
388
+ # Ignore copy
389
+ def forward(
390
+ self,
391
+ hidden_states: torch.Tensor,
392
+ attention_mask: Optional[torch.LongTensor] = None,
393
+ position_ids: Optional[torch.LongTensor] = None,
394
+ past_key_value: Optional[Cache] = None,
395
+ output_attentions: bool = False,
396
+ use_cache: bool = False,
397
+ cache_position: Optional[torch.LongTensor] = None,
398
+ **kwargs,
399
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
400
+ if isinstance(past_key_value, StaticCache):
401
+ raise ValueError(
402
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
403
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
404
+ )
405
+
406
+ output_attentions = False
407
+
408
+ bsz, q_len, _ = hidden_states.size()
409
+
410
+ query_states = self.q_proj(hidden_states)
411
+ key_states = self.k_proj(hidden_states)
412
+ value_states = self.v_proj(hidden_states)
413
+
414
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
415
+ query_states = self.q_norm(query_states)
416
+
417
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
418
+ key_states = self.k_norm(key_states)
419
+
420
+ # Flash attention requires the input to have the shape
421
+ # batch_size x seq_length x head_dim x hidden_dim
422
+ # therefore we just need to keep the original shape
423
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
424
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
425
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
426
+
427
+ cos, sin = self.rotary_emb(value_states, position_ids)
428
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
429
+
430
+ if past_key_value is not None:
431
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
432
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
433
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
434
+
435
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
436
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
437
+ query_states = query_states.transpose(1, 2)
438
+ key_states = key_states.transpose(1, 2)
439
+ value_states = value_states.transpose(1, 2)
440
+
441
+ dropout_rate = self.attention_dropout if self.training else 0.0
442
+
443
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
444
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
445
+ # cast them back in the correct dtype just to be sure everything works as expected.
446
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
447
+ # in fp32. (ChameleonRMSNorm handles it correctly)
448
+
449
+ input_dtype = query_states.dtype
450
+ if input_dtype == torch.float32:
451
+ if torch.is_autocast_enabled():
452
+ target_dtype = torch.get_autocast_gpu_dtype()
453
+ # Handle the case where the model is quantized
454
+ elif hasattr(self.config, "_pre_quantization_dtype"):
455
+ target_dtype = self.config._pre_quantization_dtype
456
+ else:
457
+ target_dtype = self.q_proj.weight.dtype
458
+
459
+ logger.warning_once(
460
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
461
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
462
+ f" {target_dtype}."
463
+ )
464
+
465
+ query_states = query_states.to(target_dtype)
466
+ key_states = key_states.to(target_dtype)
467
+ value_states = value_states.to(target_dtype)
468
+
469
+ attn_output = _flash_attention_forward(
470
+ query_states,
471
+ key_states,
472
+ value_states,
473
+ attention_mask,
474
+ q_len,
475
+ dropout=dropout_rate,
476
+ sliding_window=getattr(self, "sliding_window", None),
477
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
478
+ is_causal=self.is_causal,
479
+ )
480
+
481
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
482
+ attn_output = self.o_proj(attn_output)
483
+
484
+ if not output_attentions:
485
+ attn_weights = None
486
+
487
+ return attn_output, attn_weights, past_key_value
488
+
489
+
490
+ class ChameleonSdpaAttention(ChameleonAttention):
491
+ """
492
+ Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
493
+ `ChameleonAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
494
+ SDPA API.
495
+ """
496
+
497
+ # Adapted from ChameleonAttention.forward
498
+ def forward(
499
+ self,
500
+ hidden_states: torch.Tensor,
501
+ attention_mask: Optional[torch.Tensor] = None,
502
+ position_ids: Optional[torch.LongTensor] = None,
503
+ past_key_value: Optional[Cache] = None,
504
+ output_attentions: bool = False,
505
+ use_cache: bool = False,
506
+ cache_position: Optional[torch.LongTensor] = None,
507
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
508
+ if output_attentions:
509
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
510
+ logger.warning_once(
511
+ "ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
512
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
513
+ )
514
+ return super().forward(
515
+ hidden_states=hidden_states,
516
+ attention_mask=attention_mask,
517
+ position_ids=position_ids,
518
+ past_key_value=past_key_value,
519
+ output_attentions=output_attentions,
520
+ use_cache=use_cache,
521
+ cache_position=cache_position,
522
+ )
523
+
524
+ bsz, q_len, _ = hidden_states.size()
525
+
526
+ query_states = self.q_proj(hidden_states)
527
+ key_states = self.k_proj(hidden_states)
528
+ value_states = self.v_proj(hidden_states)
529
+
530
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
531
+ query_states = self.q_norm(query_states)
532
+
533
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
534
+ key_states = self.k_norm(key_states)
535
+
536
+ query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
537
+ key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
538
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
539
+
540
+ cos, sin = self.rotary_emb(value_states, position_ids)
541
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
542
+
543
+ if past_key_value is not None:
544
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
545
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
546
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
547
+
548
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
549
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
550
+
551
+ causal_mask = attention_mask
552
+ if attention_mask is not None and cache_position is not None:
553
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
554
+
555
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
556
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
557
+ if query_states.device.type == "cuda" and causal_mask is not None:
558
+ query_states = query_states.contiguous()
559
+ key_states = key_states.contiguous()
560
+ value_states = value_states.contiguous()
561
+
562
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
563
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
564
+ is_causal = True if causal_mask is None and q_len > 1 else False
565
+
566
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
567
+ query_states,
568
+ key_states,
569
+ value_states,
570
+ attn_mask=causal_mask,
571
+ dropout_p=self.attention_dropout if self.training else 0.0,
572
+ is_causal=is_causal,
573
+ )
574
+
575
+ attn_output = attn_output.transpose(1, 2).contiguous()
576
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
577
+
578
+ attn_output = self.o_proj(attn_output)
579
+
580
+ return attn_output, None, past_key_value
581
+
582
+
583
+ CHAMELEON_ATTENTION_CLASSES = {
584
+ "eager": ChameleonAttention,
585
+ "flash_attention_2": ChameleonFlashAttention2,
586
+ "sdpa": ChameleonSdpaAttention,
587
+ }
588
+
589
+
590
+ # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
591
+ # TODO(joao): add me back asap :)
592
+ class ChameleonDecoderLayer(nn.Module):
593
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
594
+ super().__init__()
595
+ self.hidden_size = config.hidden_size
596
+
597
+ self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
598
+
599
+ self.mlp = ChameleonMLP(config)
600
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
601
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+
603
+ def forward(
604
+ self,
605
+ hidden_states: torch.Tensor,
606
+ attention_mask: Optional[torch.Tensor] = None,
607
+ position_ids: Optional[torch.LongTensor] = None,
608
+ past_key_value: Optional[Cache] = None,
609
+ output_attentions: Optional[bool] = False,
610
+ use_cache: Optional[bool] = False,
611
+ cache_position: Optional[torch.LongTensor] = None,
612
+ **kwargs,
613
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
617
+ attention_mask (`torch.FloatTensor`, *optional*):
618
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
619
+ query_sequence_length, key_sequence_length)` if default attention is used.
620
+ output_attentions (`bool`, *optional*):
621
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
622
+ returned tensors for more detail.
623
+ use_cache (`bool`, *optional*):
624
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
625
+ (see `past_key_values`).
626
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
627
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
628
+ Indices depicting the position of the input sequence tokens in the sequence
629
+ kwargs (`dict`, *optional*):
630
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
631
+ into the model
632
+ """
633
+ residual = hidden_states
634
+
635
+ hidden_states = self.input_layernorm(hidden_states)
636
+
637
+ # Self Attention
638
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
639
+ hidden_states=hidden_states,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_value=past_key_value,
643
+ output_attentions=output_attentions,
644
+ use_cache=use_cache,
645
+ cache_position=cache_position,
646
+ **kwargs,
647
+ )
648
+ hidden_states = residual + hidden_states
649
+
650
+ # Fully Connected
651
+ residual = hidden_states
652
+ hidden_states = self.post_attention_layernorm(hidden_states)
653
+ hidden_states = self.mlp(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+
656
+ outputs = (hidden_states,)
657
+
658
+ if output_attentions:
659
+ outputs += (self_attn_weights,)
660
+
661
+ if use_cache:
662
+ outputs += (present_key_value,)
663
+
664
+ return outputs
665
+
666
+
667
+ class ChameleonSwinDecoderLayer(nn.Module):
668
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
669
+ super().__init__()
670
+ self.hidden_size = config.hidden_size
671
+
672
+ self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
673
+
674
+ self.mlp = ChameleonMLP(config)
675
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
676
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states: torch.Tensor,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ position_ids: Optional[torch.LongTensor] = None,
683
+ past_key_value: Optional[Cache] = None,
684
+ output_attentions: Optional[bool] = False,
685
+ use_cache: Optional[bool] = False,
686
+ cache_position: Optional[torch.LongTensor] = None,
687
+ **kwargs,
688
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
689
+ """
690
+ Args:
691
+ hidden_states (`torch.FloatTensor`):
692
+ input to the layer of shape `(batch, seq_len, embed_dim)`
693
+ attention_mask (`torch.FloatTensor`, *optional*):
694
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
695
+ query_sequence_length, key_sequence_length)` if default attention is used.
696
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
697
+ Indices of positions of each input sequence tokens in the position embeddings
698
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
699
+ output_attentions (`bool`, *optional*):
700
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
701
+ returned tensors for more detail.
702
+ use_cache (`bool`, *optional*):
703
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
704
+ (see `past_key_values`).
705
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
706
+ Indices depicting the position of the input sequence tokens in the sequence.
707
+ """
708
+
709
+ residual = hidden_states
710
+
711
+ # Self Attention
712
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
713
+ hidden_states=hidden_states,
714
+ attention_mask=attention_mask,
715
+ position_ids=position_ids,
716
+ past_key_value=past_key_value,
717
+ output_attentions=output_attentions,
718
+ use_cache=use_cache,
719
+ cache_position=cache_position,
720
+ **kwargs,
721
+ )
722
+ hidden_states = self.input_layernorm(hidden_states)
723
+ hidden_states = residual + hidden_states
724
+ # Fully Connected
725
+ residual = hidden_states
726
+ hidden_states = self.mlp(hidden_states)
727
+ hidden_states = self.post_attention_layernorm(hidden_states)
728
+ hidden_states = residual + hidden_states
729
+ outputs = (hidden_states,)
730
+
731
+ if output_attentions:
732
+ outputs += (self_attn_weights,)
733
+
734
+ if use_cache:
735
+ outputs += (present_key_value,)
736
+
737
+ return outputs
738
+
739
+
740
+ class ChameleonVQVAEVectorQuantizer(nn.Module):
741
+ """
742
+ A module for vector quantization using learned embedding vectors.
743
+
744
+ This module implements the quantization process similar to the one described in
745
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
746
+ input vectors into discrete codebook vectors, which are learned during training.
747
+ Current implementation improves over previous ones by avoiding costly matrix multiplications
748
+ and allowing for post-hoc remapping of indices.
749
+ """
750
+
751
+ def __init__(self, config):
752
+ super().__init__()
753
+ self.num_embeddings = config.num_embeddings
754
+ self.embedding_dim = config.embed_dim
755
+ self.beta = getattr(config, "beta", 0.25)
756
+
757
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
758
+
759
+ def forward(self, hidden_state: torch.Tensor):
760
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
761
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
762
+
763
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
764
+ distances = (
765
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
766
+ + torch.sum(self.embedding.weight**2, dim=1)
767
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
768
+ )
769
+
770
+ min_encoding_indices = torch.argmin(distances, dim=1)
771
+ hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
772
+
773
+ # compute loss for embedding
774
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
775
+ (hidden_state_quant - hidden_state.detach()) ** 2
776
+ )
777
+
778
+ # preserve gradients
779
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
780
+
781
+ # reshape back to match original input shape
782
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
783
+
784
+ return hidden_state_quant, loss, min_encoding_indices
785
+
786
+
787
+ class ChameleonVQVAEEncoderConvDownsample(nn.Module):
788
+ def __init__(self, in_channels):
789
+ super().__init__()
790
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
791
+
792
+ def forward(self, hidden_states):
793
+ # no asymmetric padding in torch conv, must do it ourselves
794
+ hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
795
+ hidden_states = self.conv(hidden_states)
796
+ return hidden_states
797
+
798
+
799
+ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
800
+ def __init__(
801
+ self,
802
+ config,
803
+ in_channels,
804
+ out_channels=None,
805
+ conv_shortcut=False,
806
+ ):
807
+ super().__init__()
808
+ self.in_channels = in_channels
809
+ self.out_channels = in_channels if out_channels is None else out_channels
810
+ self.use_conv_shortcut = conv_shortcut
811
+
812
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
813
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
814
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
815
+ self.dropout = torch.nn.Dropout(config.dropout)
816
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
817
+ if self.in_channels != self.out_channels:
818
+ if self.use_conv_shortcut:
819
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
820
+ else:
821
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
822
+
823
+ def forward(self, hidden_states):
824
+ residual = hidden_states
825
+ hidden_states = self.norm1(hidden_states)
826
+ hidden_states *= torch.sigmoid(hidden_states)
827
+ hidden_states = self.conv1(hidden_states)
828
+
829
+ hidden_states = self.norm2(hidden_states)
830
+ hidden_states *= torch.sigmoid(hidden_states)
831
+ hidden_states = self.dropout(hidden_states)
832
+ hidden_states = self.conv2(hidden_states)
833
+
834
+ if self.in_channels != self.out_channels:
835
+ if self.use_conv_shortcut:
836
+ residual = self.conv_shortcut(residual)
837
+ else:
838
+ residual = self.nin_shortcut(residual)
839
+
840
+ return residual + hidden_states
841
+
842
+
843
+ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
844
+ def __init__(self, in_channels):
845
+ super().__init__()
846
+ self.in_channels = in_channels
847
+
848
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
849
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
850
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
851
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
852
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
853
+
854
+ def forward(self, hidden_states):
855
+ residual = hidden_states
856
+ hidden_states = self.norm(hidden_states)
857
+ query_states = self.q(hidden_states)
858
+ key_states = self.k(hidden_states)
859
+ value_states = self.v(hidden_states)
860
+
861
+ # compute attention
862
+ batch_size, channels, height, width = query_states.shape
863
+ query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
864
+ key_states = key_states.reshape(batch_size, channels, height * width)
865
+ attn_weights = torch.bmm(query_states, key_states)
866
+ attn_weights = attn_weights * (int(channels) ** (-0.5))
867
+ attn_weights = F.softmax(attn_weights, dim=2)
868
+
869
+ # attend to values
870
+ value_states = value_states.reshape(batch_size, channels, height * width)
871
+ attn_weights = attn_weights.permute(0, 2, 1)
872
+ attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
873
+
874
+ attn_output = self.proj_out(attn_output)
875
+ return residual + attn_output
876
+
877
+
878
+ class ChameleonVQVAEEncoder(nn.Module):
879
+ def __init__(self, config):
880
+ super().__init__()
881
+
882
+ self.num_resolutions = len(config.channel_multiplier)
883
+ self.num_res_blocks = config.num_res_blocks
884
+ base_channels = config.base_channels
885
+ resolution = config.resolution
886
+ in_channels = config.in_channels
887
+ double_latent = config.double_latent
888
+ latent_channels = config.latent_channels
889
+ channel_multiplier = config.channel_multiplier
890
+
891
+ self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
892
+
893
+ curr_res = resolution
894
+ in_channel_multiplier = (1,) + tuple(channel_multiplier)
895
+ self.in_channel_multiplier = in_channel_multiplier
896
+ self.down = nn.ModuleList()
897
+ for i_level in range(self.num_resolutions):
898
+ block = nn.ModuleList()
899
+ attn = nn.ModuleList()
900
+ block_in = base_channels * in_channel_multiplier[i_level]
901
+ block_out = base_channels * channel_multiplier[i_level]
902
+ for i_block in range(self.num_res_blocks):
903
+ block.append(
904
+ ChameleonVQVAEEncoderResnetBlock(
905
+ config=config,
906
+ in_channels=block_in,
907
+ out_channels=block_out,
908
+ )
909
+ )
910
+ block_in = block_out
911
+ if (
912
+ config.attn_resolutions is not None
913
+ and curr_res in config.attn_resolutions
914
+ and config.attn_type == "vanilla"
915
+ ):
916
+ attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
917
+
918
+ down = nn.Module()
919
+ down.block = block
920
+ down.attn = attn
921
+ if i_level != self.num_resolutions - 1:
922
+ down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
923
+ curr_res = curr_res // 2
924
+ self.down.append(down)
925
+
926
+ self.mid = nn.Module()
927
+ self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
928
+ config=config,
929
+ in_channels=block_in,
930
+ out_channels=block_in,
931
+ )
932
+ self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
933
+ self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
934
+ config=config,
935
+ in_channels=block_in,
936
+ out_channels=block_in,
937
+ )
938
+
939
+ self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
940
+ self.conv_out = torch.nn.Conv2d(
941
+ block_in,
942
+ 2 * latent_channels if double_latent else latent_channels,
943
+ kernel_size=3,
944
+ stride=1,
945
+ padding=1,
946
+ )
947
+
948
+ def forward(self, pixel_values: torch.LongTensor):
949
+ # downsampling
950
+ hidden_states = [self.conv_in(pixel_values)]
951
+ for i_level in range(self.num_resolutions):
952
+ for i_block in range(self.num_res_blocks):
953
+ hidden_state = self.down[i_level].block[i_block](
954
+ hidden_states[-1],
955
+ )
956
+ if len(self.down[i_level].attn) > 0:
957
+ hidden_state = self.down[i_level].attn[i_block](hidden_state)
958
+ hidden_states.append(hidden_state)
959
+ if i_level != self.num_resolutions - 1:
960
+ hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
961
+
962
+ # middle
963
+ last_hidden_state = hidden_states[-1]
964
+ last_hidden_state = self.mid.block_1(last_hidden_state)
965
+ last_hidden_state = self.mid.attn_1(last_hidden_state)
966
+ last_hidden_state = self.mid.block_2(last_hidden_state)
967
+
968
+ # end
969
+ last_hidden_state = self.norm_out(last_hidden_state)
970
+ last_hidden_state *= torch.sigmoid(last_hidden_state)
971
+ last_hidden_state = self.conv_out(last_hidden_state)
972
+ return last_hidden_state
973
+
974
+
975
+ class ChameleonImageVocabularyMapping:
976
+ """
977
+ A class for mapping discrete image tokens from VQGAN to BPE tokens.
978
+ """
979
+
980
+ def __init__(self, vocab_map):
981
+ self.vocab_map = vocab_map
982
+ self.image_token_id = vocab_map.get("<image>")
983
+
984
+ @cached_property
985
+ def val2name(self):
986
+ return {v: k for k, v in self.vocab_map.items()}
987
+
988
+ @cached_property
989
+ def image_tokens(self):
990
+ return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")])
991
+
992
+ @cached_property
993
+ def bpe2img(self):
994
+ img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
995
+
996
+ def remap(old_name: str) -> str:
997
+ return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
998
+
999
+ return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
1000
+
1001
+ @cached_property
1002
+ def img2bpe(self):
1003
+ return {v: k for k, v in self.bpe2img.items()}
1004
+
1005
+ @cached_property
1006
+ def bpe2img_search_tensors(self):
1007
+ return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values()))
1008
+
1009
+ @cached_property
1010
+ def img2bpe_mapping_tensor(self):
1011
+ mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
1012
+ for k, v in self.img2bpe.items():
1013
+ mapping[k] = v
1014
+ return mapping
1015
+
1016
+ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
1017
+ device = img_batch.device
1018
+ img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
1019
+ return img_tokens.to(device)
1020
+
1021
+
1022
+ CHAMELEON_START_DOCSTRING = r"""
1023
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1024
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1025
+ etc.)
1026
+
1027
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1028
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1029
+ and behavior.
1030
+
1031
+ Parameters:
1032
+ config ([`ChameleonConfig`]):
1033
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1034
+ load the weights associated with the model, only the configuration. Check out the
1035
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1036
+ """
1037
+
1038
+
1039
+ @add_start_docstrings(
1040
+ "The bare chameleon Model outputting raw hidden-states without any specific head on top.",
1041
+ CHAMELEON_START_DOCSTRING,
1042
+ )
1043
+ class ChameleonPreTrainedModel(PreTrainedModel):
1044
+ config_class = ChameleonConfig
1045
+ base_model_prefix = "model"
1046
+ supports_gradient_checkpointing = True
1047
+ _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
1048
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
1049
+ _supports_flash_attn_2 = True
1050
+ _supports_sdpa = True
1051
+ _supports_quantized_cache = True
1052
+ _supports_cache_class = True
1053
+ _supports_static_cache = True
1054
+ _supports_param_buffer_assignment = False
1055
+
1056
+ def _init_weights(self, module):
1057
+ std = self.config.initializer_range
1058
+
1059
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
1060
+ module.weight.data.normal_(mean=0.0, std=std)
1061
+ if module.bias is not None:
1062
+ module.bias.data.zero_()
1063
+ elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
1064
+ module.bias.data.zero_()
1065
+ module.weight.data.fill_(1.0)
1066
+ elif isinstance(module, ChameleonRMSNorm):
1067
+ module.weight.data.fill_(1.0)
1068
+ elif isinstance(module, nn.Embedding):
1069
+ module.weight.data.normal_(mean=0.0, std=std)
1070
+ if module.padding_idx is not None:
1071
+ module.weight.data[module.padding_idx].zero_()
1072
+
1073
+
1074
+ CHAMELEON_VQ_START_DOCSTRING = r"""
1075
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1076
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1077
+ etc.)
1078
+
1079
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1080
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1081
+ and behavior.
1082
+
1083
+ Parameters:
1084
+ config ([`ChameleonVQVAEConfig`]):
1085
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1086
+ load the weights associated with the model, only the configuration. Check out the
1087
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1088
+ """
1089
+
1090
+
1091
+ @add_start_docstrings(
1092
+ """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
1093
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
1094
+ [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
1095
+ """,
1096
+ CHAMELEON_VQ_START_DOCSTRING,
1097
+ )
1098
+ class ChameleonVQVAE(ChameleonPreTrainedModel):
1099
+ config_class = ChameleonVQVAEConfig
1100
+ _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
1101
+
1102
+ def __init__(self, config: ChameleonVQVAEConfig):
1103
+ super().__init__(config)
1104
+
1105
+ self.encoder = ChameleonVQVAEEncoder(config)
1106
+ self.quantize = ChameleonVQVAEVectorQuantizer(config)
1107
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
1108
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
1109
+ self.eval() # Chameleon's VQ model is frozen
1110
+
1111
+ def encode(self, pixel_values: torch.LongTensor):
1112
+ hidden_states = self.encoder(pixel_values)
1113
+ hidden_states = self.quant_conv(hidden_states)
1114
+ quant, emb_loss, indices = self.quantize(hidden_states)
1115
+ return quant, emb_loss, indices
1116
+
1117
+
1118
+ CHAMELEON_INPUTS_DOCSTRING = r"""
1119
+ Args:
1120
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1121
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1122
+ it.
1123
+
1124
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1125
+ [`PreTrainedTokenizer.__call__`] for details.
1126
+
1127
+ [What are input IDs?](../glossary#input-ids)
1128
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1129
+ The tensors corresponding to the input images. Pixel values can be obtained using
1130
+ [`AutoImageProcessor`]. See [`ChameleonImageProcessor.__call__`] for details.
1131
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1132
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1133
+
1134
+ - 1 for tokens that are **not masked**,
1135
+ - 0 for tokens that are **masked**.
1136
+
1137
+ [What are attention masks?](../glossary#attention-mask)
1138
+
1139
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1140
+ [`PreTrainedTokenizer.__call__`] for details.
1141
+
1142
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1143
+ `past_key_values`).
1144
+
1145
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1146
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1147
+ information on the default strategy.
1148
+
1149
+ - 1 indicates the head is **not masked**,
1150
+ - 0 indicates the head is **masked**.
1151
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1153
+ config.n_positions - 1]`.
1154
+
1155
+ [What are position IDs?](../glossary#position-ids)
1156
+ past_key_values (`Cache`, *optional*):
1157
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1158
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1159
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1160
+
1161
+ Should always be a [`~cache_utils.Cache`] instance and the model will output the same cache instance.
1162
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1163
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1164
+ of shape `(batch_size, sequence_length)`.
1165
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1166
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1167
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1168
+ model's internal embedding lookup matrix.
1169
+ use_cache (`bool`, *optional*):
1170
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1171
+ `past_key_values`).
1172
+ output_attentions (`bool`, *optional*):
1173
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1174
+ tensors for more detail.
1175
+ output_hidden_states (`bool`, *optional*):
1176
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1177
+ more detail.
1178
+ return_dict (`bool`, *optional*):
1179
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1180
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1181
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1182
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1183
+ the complete sequence length.
1184
+ """
1185
+
1186
+
1187
+ @add_start_docstrings(
1188
+ "The bare chameleon Model outputting raw hidden-states without any specific head on top.",
1189
+ CHAMELEON_START_DOCSTRING,
1190
+ )
1191
+ class ChameleonModel(ChameleonPreTrainedModel):
1192
+ """
1193
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ChameleonDecoderLayer`]
1194
+
1195
+ Args:
1196
+ config: ChameleonConfig
1197
+ """
1198
+
1199
+ def __init__(self, config: ChameleonConfig):
1200
+ super().__init__(config)
1201
+ self.padding_idx = config.pad_token_id
1202
+ self.vocab_size = config.vocab_size
1203
+
1204
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1205
+ self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
1206
+ decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
1207
+ self.layers = nn.ModuleList(
1208
+ [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1209
+ )
1210
+ self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1211
+ self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
1212
+ self.gradient_checkpointing = False
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ self.post_init()
1216
+
1217
+ def get_input_embeddings(self):
1218
+ return self.embed_tokens
1219
+
1220
+ def set_input_embeddings(self, value):
1221
+ self.embed_tokens = value
1222
+
1223
+ def get_image_tokens(self, pixel_values: torch.FloatTensor):
1224
+ """
1225
+ Tokenizes images into discrete tokens with VQGAN module. Converts
1226
+ obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
1227
+ special tokens.
1228
+
1229
+ Args:
1230
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1231
+ The tensors corresponding to the input images.
1232
+ """
1233
+ batch_size = pixel_values.shape[0]
1234
+ _, _, image_toks = self.vqmodel.encode(pixel_values)
1235
+ bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
1236
+ bpe_toks = bpe_toks.view(batch_size, -1)
1237
+ return bpe_toks
1238
+
1239
+ @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
1240
+ @add_code_sample_docstrings(
1241
+ checkpoint=_CHECKPOINT_FOR_DOC,
1242
+ output_type=BaseModelOutputWithPast,
1243
+ config_class=_CONFIG_FOR_DOC,
1244
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1245
+ )
1246
+ def forward(
1247
+ self,
1248
+ input_ids: Optional[torch.LongTensor] = None,
1249
+ pixel_values: Optional[torch.FloatTensor] = None,
1250
+ attention_mask: Optional[torch.Tensor] = None,
1251
+ position_ids: Optional[torch.LongTensor] = None,
1252
+ past_key_values: Optional[Cache] = None,
1253
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1254
+ use_cache: Optional[bool] = None,
1255
+ output_attentions: Optional[bool] = None,
1256
+ output_hidden_states: Optional[bool] = None,
1257
+ return_dict: Optional[bool] = None,
1258
+ cache_position: Optional[torch.LongTensor] = None,
1259
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1260
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1261
+ output_hidden_states = (
1262
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1263
+ )
1264
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1265
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1266
+
1267
+ if self.gradient_checkpointing and self.training and use_cache:
1268
+ logger.warning_once(
1269
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1270
+ )
1271
+ use_cache = False
1272
+
1273
+ if (input_ids is None) ^ (inputs_embeds is not None):
1274
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1275
+
1276
+ if pixel_values is not None and inputs_embeds is not None:
1277
+ raise ValueError(
1278
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
1279
+ )
1280
+
1281
+ if pixel_values is not None:
1282
+ image_tokens = self.get_image_tokens(pixel_values)
1283
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
1284
+ if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
1285
+ n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
1286
+ n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
1287
+ raise ValueError(
1288
+ f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
1289
+ )
1290
+ image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
1291
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
1292
+
1293
+ if inputs_embeds is None:
1294
+ inputs_embeds = self.embed_tokens(input_ids)
1295
+
1296
+ # torch.jit.trace() doesn't support cache objects in the output
1297
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
1298
+ past_key_values = DynamicCache()
1299
+
1300
+ if cache_position is None:
1301
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1302
+ cache_position = torch.arange(
1303
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1304
+ )
1305
+
1306
+ if position_ids is None:
1307
+ position_ids = cache_position.unsqueeze(0)
1308
+
1309
+ causal_mask = self._update_causal_mask(
1310
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1311
+ )
1312
+
1313
+ # embed positions
1314
+ hidden_states = inputs_embeds
1315
+
1316
+ # decoder layers
1317
+ all_hidden_states = () if output_hidden_states else None
1318
+ all_self_attns = () if output_attentions else None
1319
+ next_decoder_cache = None
1320
+
1321
+ for decoder_layer in self.layers:
1322
+ if output_hidden_states:
1323
+ all_hidden_states += (hidden_states,)
1324
+
1325
+ if self.gradient_checkpointing and self.training:
1326
+ layer_outputs = self._gradient_checkpointing_func(
1327
+ decoder_layer.__call__,
1328
+ hidden_states,
1329
+ causal_mask,
1330
+ position_ids,
1331
+ past_key_values,
1332
+ output_attentions,
1333
+ use_cache,
1334
+ cache_position,
1335
+ )
1336
+ else:
1337
+ layer_outputs = decoder_layer(
1338
+ hidden_states,
1339
+ attention_mask=causal_mask,
1340
+ position_ids=position_ids,
1341
+ past_key_value=past_key_values,
1342
+ output_attentions=output_attentions,
1343
+ use_cache=use_cache,
1344
+ cache_position=cache_position,
1345
+ )
1346
+
1347
+ hidden_states = layer_outputs[0]
1348
+
1349
+ if use_cache:
1350
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1351
+
1352
+ if output_attentions:
1353
+ all_self_attns += (layer_outputs[1],)
1354
+
1355
+ hidden_states = self.norm(hidden_states)
1356
+
1357
+ # add hidden states from the last decoder layer
1358
+ if output_hidden_states:
1359
+ all_hidden_states += (hidden_states,)
1360
+
1361
+ next_cache = None
1362
+ if use_cache:
1363
+ next_cache = next_decoder_cache
1364
+
1365
+ if not return_dict:
1366
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1367
+
1368
+ return BaseModelOutputWithPast(
1369
+ last_hidden_state=hidden_states,
1370
+ past_key_values=next_cache,
1371
+ hidden_states=all_hidden_states,
1372
+ attentions=all_self_attns,
1373
+ )
1374
+
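The image-token splicing in `forward` above boils down to a `masked_scatter` over the placeholder positions: every occurrence of the image placeholder id is replaced, in order, by one BPE-mapped VQ code. A toy, self-contained illustration (all ids invented):

```python
import torch

image_token_id = 9                               # stand-in for vocabulary_mapping.image_token_id
input_ids = torch.tensor([[1, 9, 9, 9, 2]])      # text ids with three image placeholders
image_tokens = torch.tensor([[101, 102, 103]])   # BPE ids as produced by get_image_tokens

special_image_mask = input_ids == image_token_id
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
print(input_ids)  # tensor([[  1, 101, 102, 103,   2]])
```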
1375
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1376
+ def _update_causal_mask(
1377
+ self,
1378
+ attention_mask: Union[torch.Tensor, "BlockMask"],
1379
+ input_tensor: torch.Tensor,
1380
+ cache_position: torch.Tensor,
1381
+ past_key_values: Cache,
1382
+ output_attentions: bool = False,
1383
+ ):
1384
+ if self.config._attn_implementation == "flash_attention_2":
1385
+ if attention_mask is not None and (attention_mask == 0.0).any():
1386
+ return attention_mask
1387
+ return None
1388
+ if self.config._attn_implementation == "flex_attention":
1389
+ if isinstance(attention_mask, torch.Tensor):
1390
+ attention_mask = make_flex_block_causal_mask(attention_mask)
1391
+ return attention_mask
1392
+
1393
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1394
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1395
+ # to infer the attention mask.
1396
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1397
+ using_static_cache = isinstance(past_key_values, StaticCache)
1398
+
1399
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1400
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1401
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1402
+ attention_mask,
1403
+ inputs_embeds=input_tensor,
1404
+ past_key_values_length=past_seen_tokens,
1405
+ is_training=self.training,
1406
+ ):
1407
+ return None
1408
+
1409
+ dtype, device = input_tensor.dtype, input_tensor.device
1410
+ sequence_length = input_tensor.shape[1]
1411
+ if using_static_cache:
1412
+ target_length = past_key_values.get_max_cache_shape()
1413
+ else:
1414
+ target_length = (
1415
+ attention_mask.shape[-1]
1416
+ if isinstance(attention_mask, torch.Tensor)
1417
+ else past_seen_tokens + sequence_length + 1
1418
+ )
1419
+
1420
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1421
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1422
+ attention_mask,
1423
+ sequence_length=sequence_length,
1424
+ target_length=target_length,
1425
+ dtype=dtype,
1426
+ device=device,
1427
+ cache_position=cache_position,
1428
+ batch_size=input_tensor.shape[0],
1429
+ )
1430
+
1431
+ if (
1432
+ self.config._attn_implementation == "sdpa"
1433
+ and attention_mask is not None
1434
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1435
+ and not output_attentions
1436
+ ):
1437
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1438
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1439
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1440
+ min_dtype = torch.finfo(dtype).min
1441
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1442
+
1443
+ return causal_mask
1444
+
1445
+ @staticmethod
1446
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
1447
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1448
+ attention_mask: torch.Tensor,
1449
+ sequence_length: int,
1450
+ target_length: int,
1451
+ dtype: torch.dtype,
1452
+ device: torch.device,
1453
+ cache_position: torch.Tensor,
1454
+ batch_size: int,
1455
+ **kwargs,
1456
+ ):
1457
+ """
1458
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1459
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1460
+
1461
+ Args:
1462
+ attention_mask (`torch.Tensor`):
1463
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1464
+ `(batch_size, 1, query_length, key_value_length)`.
1465
+ sequence_length (`int`):
1466
+ The sequence length being processed.
1467
+ target_length (`int`):
1468
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1469
+ to account for the 0 padding, the part of the cache that is not filled yet.
1470
+ dtype (`torch.dtype`):
1471
+ The dtype to use for the 4D attention mask.
1472
+ device (`torch.device`):
1473
+ The device to place the 4D attention mask on.
1474
+ cache_position (`torch.Tensor`):
1475
+ Indices depicting the position of the input sequence tokens in the sequence.
1476
+ batch_size (`int`):
1477
+ Batch size.
1478
+ """
1479
+ if attention_mask is not None and attention_mask.dim() == 4:
1480
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1481
+ causal_mask = attention_mask
1482
+ else:
1483
+ min_dtype = torch.finfo(dtype).min
1484
+ causal_mask = torch.full(
1485
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1486
+ )
1487
+ if sequence_length != 1:
1488
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1489
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1490
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1491
+ if attention_mask is not None:
1492
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1493
+ mask_length = attention_mask.shape[-1]
1494
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1495
+ causal_mask.device
1496
+ )
1497
+ padding_mask = padding_mask == 0
1498
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1499
+ padding_mask, min_dtype
1500
+ )
1501
+
1502
+ return causal_mask
1503
+
1504
+
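The 2D-to-4D expansion performed by `_prepare_4d_causal_attention_mask_with_cache_position` above can be reproduced at toy scale; this sketch keeps only the causal part and omits the padding-merge step applied when a 2D `attention_mask` is provided:

```python
import torch

dtype = torch.float32
batch_size, sequence_length, target_length = 1, 3, 3
cache_position = torch.arange(sequence_length)

min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
print(causal_mask[0, 0])  # strictly-upper-triangular entries are min_dtype, the rest are 0
```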
1505
+ @add_start_docstrings(
1506
+ "Chameleon Model with a head on top used for outputting logits for next token prediction.",
1507
+ CHAMELEON_START_DOCSTRING,
1508
+ )
1509
+ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
1510
+ _tied_weights_keys = ["lm_head.weight"]
1511
+
1512
+ def __init__(self, config):
1513
+ super().__init__(config)
1514
+ self.model = ChameleonModel(config)
1515
+ self.vocab_size = config.vocab_size
1516
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1517
+
1518
+ # Initialize weights and apply final processing
1519
+ self.post_init()
1520
+
1521
+ def get_input_embeddings(self):
1522
+ return self.model.embed_tokens
1523
+
1524
+ def set_input_embeddings(self, value):
1525
+ self.model.embed_tokens = value
1526
+
1527
+ def get_output_embeddings(self):
1528
+ return self.lm_head
1529
+
1530
+ def set_output_embeddings(self, new_embeddings):
1531
+ self.lm_head = new_embeddings
1532
+
1533
+ def set_decoder(self, decoder):
1534
+ self.model = decoder
1535
+
1536
+ def get_decoder(self):
1537
+ return self.model
1538
+
1539
+ @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
1540
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1541
+ def forward(
1542
+ self,
1543
+ input_ids: Optional[torch.LongTensor] = None,
1544
+ pixel_values: Optional[torch.FloatTensor] = None,
1545
+ attention_mask: Optional[torch.Tensor] = None,
1546
+ position_ids: Optional[torch.LongTensor] = None,
1547
+ past_key_values: Optional[Cache] = None,
1548
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1549
+ labels: Optional[torch.LongTensor] = None,
1550
+ use_cache: Optional[bool] = None,
1551
+ output_attentions: Optional[bool] = None,
1552
+ output_hidden_states: Optional[bool] = None,
1553
+ return_dict: Optional[bool] = None,
1554
+ cache_position: Optional[torch.LongTensor] = None,
1555
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1556
+ r"""
1557
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1558
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1559
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1560
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1561
+
1562
+ Returns:
1563
+
1564
+ Example:
1565
+
1566
+ ```python
1567
+ >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
1568
+ >>> import torch
1569
+ >>> import requests
1570
+ >>> from PIL import Image
1571
+
1572
+ >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
1573
+ >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
1574
+
1575
+ >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
1576
+ >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
1577
+ >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
1578
+
1579
+ >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
1580
+
1581
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
1582
+ >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1583
+ ```"""
1584
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1585
+ output_hidden_states = (
1586
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1587
+ )
1588
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1589
+
1590
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1591
+ outputs = self.model(
1592
+ input_ids=input_ids,
1593
+ pixel_values=pixel_values,
1594
+ attention_mask=attention_mask,
1595
+ position_ids=position_ids,
1596
+ past_key_values=past_key_values,
1597
+ inputs_embeds=inputs_embeds,
1598
+ use_cache=use_cache,
1599
+ output_attentions=output_attentions,
1600
+ output_hidden_states=output_hidden_states,
1601
+ return_dict=return_dict,
1602
+ cache_position=cache_position,
1603
+ )
1604
+
1605
+ hidden_states = outputs[0]
1606
+ logits = self.lm_head(hidden_states)
1607
+
1608
+ # Disallow plain image tokens from being generated; this set does not include the special begin-image and end-image tokens
1609
+ image_tokens = self.model.vocabulary_mapping.image_tokens
1610
+ logits[:, :, image_tokens] = torch.finfo(logits.dtype).min
1611
+
1612
+ loss = None
1613
+ if labels is not None:
1614
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1615
+ logits = logits.float()
1616
+ # Shift so that tokens < n predict n
1617
+ shift_logits = logits[..., :-1, :].contiguous()
1618
+ shift_labels = labels[..., 1:].contiguous()
1619
+ # Flatten the tokens
1620
+ loss_fct = CrossEntropyLoss()
1621
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1622
+ shift_labels = shift_labels.view(-1)
1623
+ # Enable model parallelism
1624
+ shift_labels = shift_labels.to(shift_logits.device)
1625
+ loss = loss_fct(shift_logits, shift_labels)
1626
+
1627
+ if not return_dict:
1628
+ output = (logits,) + outputs[1:]
1629
+ return (loss,) + output if loss is not None else output
1630
+
1631
+ return CausalLMOutputWithPast(
1632
+ loss=loss,
1633
+ logits=logits,
1634
+ past_key_values=outputs.past_key_values,
1635
+ hidden_states=outputs.hidden_states,
1636
+ attentions=outputs.attentions,
1637
+ )
1638
+
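The logit masking in `forward` above keeps text-only decoding from ever emitting raw image-code ids: the corresponding vocabulary columns are set to the dtype minimum. A toy illustration with an invented vocabulary:

```python
import torch

logits = torch.zeros(1, 1, 6)   # (batch, seq, vocab); vocabulary size is invented
image_tokens = [3, 4, 5]        # stand-in for vocabulary_mapping.image_tokens
logits[:, :, image_tokens] = torch.finfo(logits.dtype).min
print(logits.argmax(dim=-1))    # tensor([[0]]) -- a masked id can never win the argmax
```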
1639
+ def prepare_inputs_for_generation(
1640
+ self,
1641
+ input_ids,
1642
+ pixel_values=None,
1643
+ past_key_values=None,
1644
+ attention_mask=None,
1645
+ inputs_embeds=None,
1646
+ cache_position=None,
1647
+ position_ids=None,
1648
+ use_cache=True,
1649
+ **kwargs,
1650
+ ):
1651
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1652
+
1653
+ model_inputs = super().prepare_inputs_for_generation(
1654
+ input_ids,
1655
+ pixel_values=pixel_values,
1656
+ past_key_values=past_key_values,
1657
+ attention_mask=attention_mask,
1658
+ inputs_embeds=inputs_embeds,
1659
+ cache_position=cache_position,
1660
+ position_ids=position_ids,
1661
+ use_cache=use_cache,
1662
+ **kwargs,
1663
+ )
1664
+
1665
+ if cache_position[0] != 0:
1666
+ # If we're in the cached decoding stage, pixel values should be `None` because the input ids no longer contain the special image token
1667
+ # Otherwise we need pixel values to be passed to model
1668
+ model_inputs["pixel_values"] = None
1669
+
1670
+ return model_inputs
1671
+
1672
+
1673
+ __all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]
docs/transformers/build/lib/transformers/models/chameleon/processing_chameleon.py ADDED
@@ -0,0 +1,177 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Chameleon.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from ...feature_extraction_utils import BatchFeature
22
+ from ...image_utils import ImageInput
23
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
24
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
25
+
26
+
27
+ class ChameleonTextKwargs(TextKwargs, total=False):
28
+ return_for_text_completion: bool
29
+
30
+
31
+ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
32
+ text_kwargs: ChameleonTextKwargs
33
+ _defaults = {
34
+ "text_kwargs": {
35
+ "padding": False,
36
+ "return_for_text_completion": False,
37
+ },
38
+ "common_kwargs": {
39
+ "return_tensors": "pt",
40
+ },
41
+ }
42
+
43
+
44
+ class ChameleonProcessor(ProcessorMixin):
45
+ r"""
46
+ Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
47
+ processor.
48
+
49
+ [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
50
+ See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
51
+
52
+ Args:
53
+ image_processor ([`ChameleonImageProcessor`]):
54
+ The image processor is a required input.
55
+ tokenizer ([`LlamaTokenizerFast`]):
56
+ The tokenizer is a required input.
57
+ image_seq_length (`int`, *optional*, defaults to 1024):
58
+ Sequence length of one image embedding.
59
+ image_token (`str`, *optional*, defaults to `"<image>"`):
60
+ The special token used to indicate image in the text.
61
+ """
62
+
63
+ attributes = ["image_processor", "tokenizer"]
64
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
65
+ valid_kwargs = ["image_seq_length", "image_token"]
66
+ image_processor_class = "ChameleonImageProcessor"
67
+
68
+ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
69
+ self.image_seq_length = image_seq_length
70
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
71
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
72
+ self.image_start_token = (
73
+ tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
74
+ ) # fixed tokens for start and end, so can hardcode
75
+ self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
76
+
77
+ super().__init__(image_processor, tokenizer)
78
+
79
+ def __call__(
80
+ self,
81
+ images: Optional[ImageInput] = None,
82
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
83
+ audio=None,
84
+ videos=None,
85
+ **kwargs: Unpack[ChameleonProcessorKwargs],
86
+ ) -> BatchFeature:
87
+ """
88
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
89
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
90
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
91
+ ChameleonImageProcessor's [`~ChameleonImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
92
+ of the above two methods for more information.
93
+
94
+ Args:
95
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
96
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
97
+ tensor. Both channels-first and channels-last formats are supported.
98
+ text (`str`, `List[str]`, `List[List[str]]`):
99
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
100
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
101
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
102
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
103
+ If set, will return tensors of a particular framework. Acceptable values are:
104
+
105
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
106
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
107
+ - `'np'`: Return NumPy `np.ndarray` objects.
108
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
109
+
110
+ Returns:
111
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
112
+
113
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
114
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
115
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
116
+ `None`).
117
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
118
+ """
119
+ # check if images and text inputs are reversed for BC
120
+ images, text = _validate_images_text_input_order(images, text)
121
+ if isinstance(text, str):
122
+ text = [text]
123
+ elif not isinstance(text, list) or not isinstance(text[0], str):
124
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
125
+ if text is None and images is None:
126
+ raise ValueError("You must provide either text or images")
127
+
128
+ output_kwargs = self._merge_kwargs(
129
+ ChameleonProcessorKwargs,
130
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
131
+ **kwargs,
132
+ )
133
+ return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
134
+
135
+ # Replace the image token with the expanded image token sequence
136
+ prompt_strings = []
137
+ one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
138
+ for sample in text:
139
+ sample = sample.replace(self.image_token, one_img_tokens)
140
+ if not return_for_text_completion:
141
+ sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
142
+ prompt_strings.append(sample)
143
+
144
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
145
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
146
+ self._check_special_mm_tokens(prompt_strings, data, modalities=["image"])
147
+
148
+ if images is not None:
149
+ data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
150
+
151
+ return BatchFeature(data=data, tensor_type=return_tensors)
152
+
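The placeholder expansion in `__call__` above can be checked without a tokenizer: every `<image>` becomes the begin-image token, `image_seq_length` copies of `<image>`, then the end-image token (the real `__call__` additionally appends `tokenizer.sep_token` unless `return_for_text_completion=True`). A tokenizer-free sketch with a shortened sequence length:

```python
image_token = "<image>"
image_start_token, image_end_token = "<racm3:break>", "<eoss>"
image_seq_length = 4  # the processor default is 1024

prompt = "Describe this image:<image>"
one_img_tokens = image_start_token + (image_token * image_seq_length) + image_end_token
print(prompt.replace(image_token, one_img_tokens))
# Describe this image:<racm3:break><image><image><image><image><eoss>
```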
153
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
154
+ def batch_decode(self, *args, **kwargs):
155
+ """
156
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
157
+ refer to the docstring of this method for more information.
158
+ """
159
+ return self.tokenizer.batch_decode(*args, **kwargs)
160
+
161
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
162
+ def decode(self, *args, **kwargs):
163
+ """
164
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
165
+ the docstring of this method for more information.
166
+ """
167
+ return self.tokenizer.decode(*args, **kwargs)
168
+
169
+ @property
170
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
171
+ def model_input_names(self):
172
+ tokenizer_input_names = self.tokenizer.model_input_names
173
+ image_processor_input_names = self.image_processor.model_input_names
174
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
175
+
176
+
177
+ __all__ = ["ChameleonProcessor"]
docs/transformers/build/lib/transformers/models/chinese_clip/configuration_chinese_clip.py ADDED
@@ -0,0 +1,434 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Chinese-CLIP model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from ...processing_utils import ProcessorMixin
23
+ from ...utils import TensorType
24
+
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...onnx import OnnxConfig
27
+ from ...utils import logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class ChineseCLIPTextConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate a
36
+ Chinese CLIP model according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the Chinese CLIP
38
+ [OFA-Sys/chinese-clip-vit-base-patch16](https:
39
+ //huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
40
+
41
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
42
+ documentation from [`PretrainedConfig`] for more information.
43
+
44
+
45
+ Args:
46
+ vocab_size (`int`, *optional*, defaults to 30522):
47
+ Vocabulary size of the CHINESE_CLIP model. Defines the number of different tokens that can be represented
48
+ by the `inputs_ids` passed when calling [`ChineseCLIPModel`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ max_position_embeddings (`int`, *optional*, defaults to 512):
65
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
66
+ just in case (e.g., 512 or 1024 or 2048).
67
+ type_vocab_size (`int`, *optional*, defaults to 2):
68
+ The vocabulary size of the `token_type_ids` passed when calling [`ChineseCLIPModel`].
69
+ initializer_range (`float`, *optional*, defaults to 0.02):
70
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
71
+ initializer_factor (`float`, *optional*, defaults to 1.0):
72
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
73
+ testing).
74
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
75
+ The epsilon used by the layer normalization layers.
76
+ pad_token_id (`int`, *optional*, defaults to 0):
77
+ Padding token id.
78
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
79
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
80
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
81
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
82
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
83
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
84
+ use_cache (`bool`, *optional*, defaults to `True`):
85
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
86
+ relevant if `config.is_decoder=True`.
87
+
88
+ Example:
89
+
90
+ ```python
91
+ >>> from transformers import ChineseCLIPTextConfig, ChineseCLIPTextModel
92
+
93
+ >>> # Initializing a ChineseCLIPTextConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
94
+ >>> configuration = ChineseCLIPTextConfig()
95
+
96
+ >>> # Initializing a ChineseCLIPTextModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
97
+ >>> model = ChineseCLIPTextModel(configuration)
98
+
99
+ >>> # Accessing the model configuration
100
+ >>> configuration = model.config
101
+ ```"""
102
+
103
+ model_type = "chinese_clip_text_model"
104
+ base_config_key = "text_config"
105
+
106
+ def __init__(
107
+ self,
108
+ vocab_size=30522,
109
+ hidden_size=768,
110
+ num_hidden_layers=12,
111
+ num_attention_heads=12,
112
+ intermediate_size=3072,
113
+ hidden_act="gelu",
114
+ hidden_dropout_prob=0.1,
115
+ attention_probs_dropout_prob=0.1,
116
+ max_position_embeddings=512,
117
+ type_vocab_size=2,
118
+ initializer_range=0.02,
119
+ initializer_factor=1.0,
120
+ layer_norm_eps=1e-12,
121
+ pad_token_id=0,
122
+ position_embedding_type="absolute",
123
+ use_cache=True,
124
+ **kwargs,
125
+ ):
126
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
127
+
128
+ self.vocab_size = vocab_size
129
+ self.hidden_size = hidden_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+ self.hidden_act = hidden_act
133
+ self.intermediate_size = intermediate_size
134
+ self.hidden_dropout_prob = hidden_dropout_prob
135
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
136
+ self.max_position_embeddings = max_position_embeddings
137
+ self.type_vocab_size = type_vocab_size
138
+ self.initializer_range = initializer_range
139
+ self.initializer_factor = initializer_factor
140
+ self.layer_norm_eps = layer_norm_eps
141
+ self.position_embedding_type = position_embedding_type
142
+ self.use_cache = use_cache
143
+
144
+
145
+ class ChineseCLIPVisionConfig(PretrainedConfig):
146
+ r"""
147
+ This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
148
+ ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
149
+ configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
150
+ [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
151
+
152
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
153
+ documentation from [`PretrainedConfig`] for more information.
154
+
155
+
156
+ Args:
157
+ hidden_size (`int`, *optional*, defaults to 768):
158
+ Dimensionality of the encoder layers and the pooler layer.
159
+ intermediate_size (`int`, *optional*, defaults to 3072):
160
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
161
+ projection_dim (`int`, *optional*, defaults to 512):
162
+ Dimensionality of text and vision projection layers.
163
+ num_hidden_layers (`int`, *optional*, defaults to 12):
164
+ Number of hidden layers in the Transformer encoder.
165
+ num_attention_heads (`int`, *optional*, defaults to 12):
166
+ Number of attention heads for each attention layer in the Transformer encoder.
167
+ num_channels (`int`, *optional*, defaults to 3):
168
+ The number of input channels.
169
+ image_size (`int`, *optional*, defaults to 224):
170
+ The size (resolution) of each image.
171
+ patch_size (`int`, *optional*, defaults to 32):
172
+ The size (resolution) of each patch.
173
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
174
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
175
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
176
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
177
+ The epsilon used by the layer normalization layers.
178
+ attention_dropout (`float`, *optional*, defaults to 0.0):
179
+ The dropout ratio for the attention probabilities.
180
+ initializer_range (`float`, *optional*, defaults to 0.02):
181
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
182
+ initializer_factor (`float`, *optional*, defaults to 1.0):
183
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
184
+ testing).
185
+ Example:
186
+ ```python
187
+ >>> from transformers import ChineseCLIPVisionConfig, ChineseCLIPVisionModel
188
+
189
+ >>> # Initializing a ChineseCLIPVisionConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
190
+ >>> configuration = ChineseCLIPVisionConfig()
191
+
192
+ >>> # Initializing a ChineseCLIPVisionModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
193
+ >>> model = ChineseCLIPVisionModel(configuration)
194
+
195
+ >>> # Accessing the model configuration
196
+ >>> configuration = model.config
197
+ ```"""
198
+
199
+ model_type = "chinese_clip_vision_model"
200
+ base_config_key = "vision_config"
201
+
202
+ def __init__(
203
+ self,
204
+ hidden_size=768,
205
+ intermediate_size=3072,
206
+ projection_dim=512,
207
+ num_hidden_layers=12,
208
+ num_attention_heads=12,
209
+ num_channels=3,
210
+ image_size=224,
211
+ patch_size=32,
212
+ hidden_act="quick_gelu",
213
+ layer_norm_eps=1e-5,
214
+ attention_dropout=0.0,
215
+ initializer_range=0.02,
216
+ initializer_factor=1.0,
217
+ **kwargs,
218
+ ):
219
+ super().__init__(**kwargs)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.intermediate_size = intermediate_size
223
+ self.projection_dim = projection_dim
224
+ self.num_hidden_layers = num_hidden_layers
225
+ self.num_attention_heads = num_attention_heads
226
+ self.num_channels = num_channels
227
+ self.patch_size = patch_size
228
+ self.image_size = image_size
229
+ self.initializer_range = initializer_range
230
+ self.initializer_factor = initializer_factor
231
+ self.attention_dropout = attention_dropout
232
+ self.layer_norm_eps = layer_norm_eps
233
+ self.hidden_act = hidden_act
234
+
235
+
236
+ class ChineseCLIPConfig(PretrainedConfig):
237
+ r"""
238
+ [`ChineseCLIPConfig`] is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used
239
+ to instantiate Chinese-CLIP model according to the specified arguments, defining the text model and vision model
240
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
241
+ Chinese-CLIP [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16)
242
+ architecture.
243
+
244
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
245
+ documentation from [`PretrainedConfig`] for more information.
246
+
247
+ Args:
248
+ text_config (`dict`, *optional*):
249
+ Dictionary of configuration options used to initialize [`ChineseCLIPTextConfig`].
250
+ vision_config (`dict`, *optional*):
251
+ Dictionary of configuration options used to initialize [`ChineseCLIPVisionConfig`].
252
+ projection_dim (`int`, *optional*, defaults to 512):
253
+ Dimensionality of text and vision projection layers.
254
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
255
+ The initial value of the *logit_scale* parameter. Default is used as per the original ChineseCLIP
256
+ implementation.
257
+ kwargs (*optional*):
258
+ Dictionary of keyword arguments.
259
+
260
+ Example:
261
+
262
+ ```python
263
+ >>> from transformers import ChineseCLIPConfig, ChineseCLIPModel
264
+
265
+ >>> # Initializing a ChineseCLIPConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
266
+ >>> configuration = ChineseCLIPConfig()
267
+
268
+ >>> # Initializing a ChineseCLIPModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
269
+ >>> model = ChineseCLIPModel(configuration)
270
+
271
+ >>> # Accessing the model configuration
272
+ >>> configuration = model.config
273
+
274
+ >>> # We can also initialize a ChineseCLIPConfig from a ChineseCLIPTextConfig and a ChineseCLIPVisionConfig
275
+
276
+ >>> # Initializing a ChineseCLIPTextConfig and ChineseCLIPVisionConfig configuration
277
+ >>> config_text = ChineseCLIPTextConfig()
278
+ >>> config_vision = ChineseCLIPVisionConfig()
279
+
280
+ >>> config = ChineseCLIPConfig.from_text_vision_configs(config_text, config_vision)
281
+ ```"""
282
+
283
+ model_type = "chinese_clip"
284
+ sub_configs = {"text_config": ChineseCLIPTextConfig, "vision_config": ChineseCLIPVisionConfig}
285
+
286
+ def __init__(
287
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
288
+ ):
289
+ # If `_config_dict` exist, we use them for the backward compatibility.
290
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
291
+ # of confusion!).
292
+ text_config_dict = kwargs.pop("text_config_dict", None)
293
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
294
+
295
+ super().__init__(**kwargs)
296
+
297
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
298
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
299
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
300
+ if text_config_dict is not None:
301
+ if text_config is None:
302
+ text_config = {}
303
+
304
+ # This is the complete result when using `text_config_dict`.
305
+ _text_config_dict = ChineseCLIPTextConfig(**text_config_dict).to_dict()
306
+
307
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
308
+ for key, value in _text_config_dict.items():
309
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
310
+ # If specified in `text_config_dict`
311
+ if key in text_config_dict:
312
+ message = (
313
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
314
+ f'The value `text_config_dict["{key}"]` will be used instead.'
315
+ )
316
+ # If inferred from default argument values (just to be super careful)
317
+ else:
318
+ message = (
319
+ f"`text_config_dict` is provided which will be used to initialize `ChineseCLIPTextConfig`. "
320
+ f'The value `text_config["{key}"]` will be overridden.'
321
+ )
322
+ logger.info(message)
323
+
324
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
325
+ text_config.update(_text_config_dict)
326
+
327
+ if vision_config_dict is not None:
328
+ if vision_config is None:
329
+ vision_config = {}
330
+
331
+ # This is the complete result when using `vision_config_dict`.
332
+ _vision_config_dict = ChineseCLIPVisionConfig(**vision_config_dict).to_dict()
333
+ # convert keys to string instead of integer
334
+ if "id2label" in _vision_config_dict:
335
+ _vision_config_dict["id2label"] = {
336
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
337
+ }
338
+
339
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
340
+ for key, value in _vision_config_dict.items():
341
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
342
+ # If specified in `vision_config_dict`
343
+ if key in vision_config_dict:
344
+ message = (
345
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
346
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
347
+ )
348
+ # If inferred from default argument values (just to be super careful)
349
+ else:
350
+ message = (
351
+ f"`vision_config_dict` is provided which will be used to initialize "
352
+ f'`ChineseCLIPVisionConfig`. The value `vision_config["{key}"]` will be overridden.'
353
+ )
354
+ logger.info(message)
355
+
356
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
357
+ vision_config.update(_vision_config_dict)
358
+
359
+ if text_config is None:
360
+ text_config = {}
361
+ logger.info("`text_config` is `None`. Initializing the `ChineseCLIPTextConfig` with default values.")
362
+
363
+ if vision_config is None:
364
+ vision_config = {}
365
+ logger.info("`vision_config` is `None`. initializing the `ChineseCLIPVisionConfig` with default values.")
366
+
367
+ self.text_config = ChineseCLIPTextConfig(**text_config)
368
+ self.vision_config = ChineseCLIPVisionConfig(**vision_config)
369
+
370
+ self.projection_dim = projection_dim
371
+ self.logit_scale_init_value = logit_scale_init_value
372
+ self.initializer_factor = 1.0
373
+ self.initializer_range = 0.02
374
+
375
+ @classmethod
376
+ def from_text_vision_configs(
377
+ cls, text_config: ChineseCLIPTextConfig, vision_config: ChineseCLIPVisionConfig, **kwargs
378
+ ):
379
+ r"""
380
+ Instantiate a [`ChineseCLIPConfig`] (or a derived class) from Chinese-CLIP text model configuration and
381
+ Chinese-CLIP vision model configuration. Returns:
382
+ [`ChineseCLIPConfig`]: An instance of a configuration object
383
+ """
384
+
385
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
386
+
387
+
388
+ class ChineseCLIPOnnxConfig(OnnxConfig):
389
+ @property
390
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
391
+ return OrderedDict(
392
+ [
393
+ ("input_ids", {0: "batch", 1: "sequence"}),
394
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
395
+ ("attention_mask", {0: "batch", 1: "sequence"}),
396
+ ]
397
+ )
398
+
399
+ @property
400
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
401
+ return OrderedDict(
402
+ [
403
+ ("logits_per_image", {0: "batch"}),
404
+ ("logits_per_text", {0: "batch"}),
405
+ ("text_embeds", {0: "batch"}),
406
+ ("image_embeds", {0: "batch"}),
407
+ ]
408
+ )
409
+
410
+ @property
411
+ def atol_for_validation(self) -> float:
412
+ return 1e-4
413
+
414
+ def generate_dummy_inputs(
415
+ self,
416
+ processor: "ProcessorMixin",
417
+ batch_size: int = -1,
418
+ seq_length: int = -1,
419
+ framework: Optional["TensorType"] = None,
420
+ ) -> Mapping[str, Any]:
421
+ text_input_dict = super().generate_dummy_inputs(
422
+ processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
423
+ )
424
+ image_input_dict = super().generate_dummy_inputs(
425
+ processor.image_processor, batch_size=batch_size, framework=framework
426
+ )
427
+ return {**text_input_dict, **image_input_dict}
428
+
429
+ @property
430
+ def default_onnx_opset(self) -> int:
431
+ return 14
432
+
433
+
434
+ __all__ = ["ChineseCLIPConfig", "ChineseCLIPOnnxConfig", "ChineseCLIPTextConfig", "ChineseCLIPVisionConfig"]
docs/transformers/build/lib/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,134 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+
18
+ import torch
19
+
20
+ from transformers import ChineseCLIPConfig, ChineseCLIPModel
21
+
22
+
23
+ def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
24
+ q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
25
+ q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
26
+
27
+ out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
28
+ out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
29
+
30
+ hf_attn_layer.q_proj.weight.data = q_proj
31
+ hf_attn_layer.q_proj.bias.data = q_proj_bias
32
+
33
+ hf_attn_layer.k_proj.weight.data = k_proj
34
+ hf_attn_layer.k_proj.bias.data = k_proj_bias
35
+
36
+ hf_attn_layer.v_proj.weight.data = v_proj
37
+ hf_attn_layer.v_proj.bias.data = v_proj_bias
38
+
39
+ hf_attn_layer.out_proj.weight.data = out_proj_weights
40
+ hf_attn_layer.out_proj.bias.data = out_proj_bias
41
+
42
+
43
+ def copy_mlp(hf_mlp, pt_weights, prefix):
44
+ copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
45
+ copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
46
+
47
+
48
+ def copy_linear(hf_linear, pt_weights, prefix):
49
+ hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
50
+ hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
51
+
52
+
53
+ def copy_layer(hf_layer, pt_weights, prefix):
54
+ # copy layer norms
55
+ copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
56
+ copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
57
+
58
+ # copy MLP
59
+ copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
60
+
61
+ # copy attn
62
+ copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
63
+
64
+
65
+ def copy_layers(hf_layers, pt_weights, prefix):
66
+ for layer_id, hf_layer in enumerate(hf_layers):
67
+ copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
68
+
69
+
70
+ def copy_text_model_and_projection(hf_model, pt_weights):
71
+ # copy projection
72
+ hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
73
+
74
+ # copy text encoder
75
+ for name, param in hf_model.text_model.named_parameters():
76
+ param.data = pt_weights[f"bert.{name}"].data
77
+
78
+
79
+ def copy_vision_model_and_projection(hf_model, pt_weights):
80
+ # copy projection
81
+ hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
82
+
83
+ # copy layer norms
84
+ copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
85
+ copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
86
+
87
+ # copy embeddings
88
+ hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
89
+ hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
90
+ hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
91
+
92
+ # copy encoder
93
+ copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
94
+
95
+
96
+ @torch.no_grad()
97
+ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
98
+ """
99
+ Copy/paste/tweak model's weights to transformers design.
100
+ """
101
+
102
+ assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
103
+ config = ChineseCLIPConfig.from_pretrained(config_path)
104
+
105
+ hf_model = ChineseCLIPModel(config).eval()
106
+
107
+ pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
108
+ pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
109
+
110
+ copy_text_model_and_projection(hf_model, pt_weights)
111
+ copy_vision_model_and_projection(hf_model, pt_weights)
112
+ hf_model.logit_scale.data = pt_weights["logit_scale"].data
113
+
114
+ hf_model.save_pretrained(pytorch_dump_folder_path)
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument(
120
+ "--pytorch_dump_folder_path",
121
+ default=None,
122
+ type=str,
123
+ help="Path to the output folder storing converted hf PyTorch model.",
124
+ )
125
+ parser.add_argument(
126
+ "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
127
+ )
128
+ parser.add_argument(
129
+ "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
130
+ )
131
+ args = parser.parse_args()
132
+
133
+ convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
134
+ print("The conversion is finished!")
docs/transformers/build/lib/transformers/models/chinese_clip/feature_extraction_chinese_clip.py ADDED
@@ -0,0 +1,38 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for Chinese-CLIP."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from ...utils.import_utils import requires
21
+ from .image_processing_chinese_clip import ChineseCLIPImageProcessor
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ @requires(backends=("vision",))
28
+ class ChineseCLIPFeatureExtractor(ChineseCLIPImageProcessor):
29
+ def __init__(self, *args, **kwargs) -> None:
30
+ warnings.warn(
31
+ "The class ChineseCLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
32
+ " Please use ChineseCLIPImageProcessor instead.",
33
+ FutureWarning,
34
+ )
35
+ super().__init__(*args, **kwargs)
36
+
37
+
38
+ __all__ = ["ChineseCLIPFeatureExtractor"]
docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip.py ADDED
@@ -0,0 +1,314 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Chinese-CLIP."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ convert_to_rgb,
24
+ get_resize_output_image_size,
25
+ resize,
26
+ to_channel_dimension_format,
27
+ )
28
+ from ...image_utils import (
29
+ OPENAI_CLIP_MEAN,
30
+ OPENAI_CLIP_STD,
31
+ ChannelDimension,
32
+ ImageInput,
33
+ PILImageResampling,
34
+ infer_channel_dimension_format,
35
+ is_scaled_image,
36
+ make_list_of_images,
37
+ to_numpy_array,
38
+ valid_images,
39
+ validate_preprocess_arguments,
40
+ )
41
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
42
+
43
+
44
+ if is_vision_available():
45
+ import PIL
46
+
47
+
48
+ from ...utils.import_utils import requires
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ @requires(backends=("vision",))
55
+ class ChineseCLIPImageProcessor(BaseImageProcessor):
56
+ r"""
57
+ Constructs a Chinese-CLIP image processor.
58
+
59
+ Args:
60
+ do_resize (`bool`, *optional*, defaults to `True`):
61
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
62
+ `do_resize` in the `preprocess` method.
63
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
64
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
65
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
66
+ method.
67
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
68
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
69
+ do_center_crop (`bool`, *optional*, defaults to `True`):
70
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
71
+ `preprocess` method.
72
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
73
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
74
+ method.
75
+ do_rescale (`bool`, *optional*, defaults to `True`):
76
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
77
+ the `preprocess` method.
78
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
79
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
80
+ method.
81
+ do_normalize (`bool`, *optional*, defaults to `True`):
82
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
83
+ image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
84
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
85
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
86
+ image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
87
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
88
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
89
90
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
91
+ Whether to convert the image to RGB.
92
+ """
93
+
94
+ model_input_names = ["pixel_values"]
95
+
96
+ def __init__(
97
+ self,
98
+ do_resize: bool = True,
99
+ size: Dict[str, int] = None,
100
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
101
+ do_center_crop: bool = True,
102
+ crop_size: Dict[str, int] = None,
103
+ do_rescale: bool = True,
104
+ rescale_factor: Union[int, float] = 1 / 255,
105
+ do_normalize: bool = True,
106
+ image_mean: Optional[Union[float, List[float]]] = None,
107
+ image_std: Optional[Union[float, List[float]]] = None,
108
+ do_convert_rgb: bool = True,
109
+ **kwargs,
110
+ ) -> None:
111
+ super().__init__(**kwargs)
112
+ size = size if size is not None else {"shortest_edge": 224}
113
+ size = get_size_dict(size, default_to_square=False)
114
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
115
+ crop_size = get_size_dict(crop_size)
116
+
117
+ self.do_resize = do_resize
118
+ self.size = size
119
+ self.resample = resample
120
+ self.do_center_crop = do_center_crop
121
+ self.crop_size = crop_size
122
+ self.do_rescale = do_rescale
123
+ self.rescale_factor = rescale_factor
124
+ self.do_normalize = do_normalize
125
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
126
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
127
+ self.do_convert_rgb = do_convert_rgb
128
+
129
+ def resize(
130
+ self,
131
+ image: np.ndarray,
132
+ size: Dict[str, int],
133
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
134
+ data_format: Optional[Union[str, ChannelDimension]] = None,
135
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
136
+ **kwargs,
137
+ ) -> np.ndarray:
138
+ """
139
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
140
+ resized to keep the input aspect ratio.
141
+
142
+ Args:
143
+ image (`np.ndarray`):
144
+ Image to resize.
145
+ size (`Dict[str, int]`):
146
+ Size of the output image.
147
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
148
+ Resampling filter to use when resizing the image.
149
+ data_format (`str` or `ChannelDimension`, *optional*):
150
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
151
+ input_data_format (`ChannelDimension` or `str`, *optional*):
152
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
153
+ image.
154
+ """
155
+ size = get_size_dict(size, default_to_square=False)
156
+ output_size = get_resize_output_image_size(
157
+ image, size=(size["height"], size["width"]), default_to_square=False, input_data_format=input_data_format
158
+ )
159
+ return resize(
160
+ image,
161
+ size=output_size,
162
+ resample=resample,
163
+ data_format=data_format,
164
+ input_data_format=input_data_format,
165
+ **kwargs,
166
+ )
167
+
168
+ @filter_out_non_signature_kwargs()
169
+ def preprocess(
170
+ self,
171
+ images: ImageInput,
172
+ do_resize: Optional[bool] = None,
173
+ size: Dict[str, int] = None,
174
+ resample: PILImageResampling = None,
175
+ do_center_crop: Optional[bool] = None,
176
+ crop_size: Optional[int] = None,
177
+ do_rescale: Optional[bool] = None,
178
+ rescale_factor: Optional[float] = None,
179
+ do_normalize: Optional[bool] = None,
180
+ image_mean: Optional[Union[float, List[float]]] = None,
181
+ image_std: Optional[Union[float, List[float]]] = None,
182
+ do_convert_rgb: Optional[bool] = None,
183
+ return_tensors: Optional[Union[str, TensorType]] = None,
184
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
185
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
186
+ ) -> BatchFeature:
187
+ """
188
+ Preprocess an image or batch of images.
189
+
190
+ Args:
191
+ images (`ImageInput`):
192
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
193
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
194
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
195
+ Whether to resize the image.
196
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
197
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
198
+ the longest edge resized to keep the input aspect ratio.
199
+ resample (`int`, *optional*, defaults to `self.resample`):
200
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
201
+ has an effect if `do_resize` is set to `True`.
202
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
203
+ Whether to center crop the image.
204
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
205
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
206
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
207
+ Whether to rescale the image.
208
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
209
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
210
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
211
+ Whether to normalize the image.
212
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
213
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
214
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
215
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
216
+ `True`.
217
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
218
+ Whether to convert the image to RGB.
219
+ return_tensors (`str` or `TensorType`, *optional*):
220
+ The type of tensors to return. Can be one of:
221
+ - Unset: Return a list of `np.ndarray`.
222
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
223
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
224
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
225
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
226
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
227
+ The channel dimension format for the output image. Can be one of:
228
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
229
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
230
+ - Unset: Use the channel dimension format of the input image.
231
+ input_data_format (`ChannelDimension` or `str`, *optional*):
232
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
233
+ from the input image. Can be one of:
234
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
235
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
236
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
237
+ """
238
+
239
+ do_resize = do_resize if do_resize is not None else self.do_resize
240
+ size = size if size is not None else self.size
241
+ size = get_size_dict(size, default_to_square=False)
242
+ resample = resample if resample is not None else self.resample
243
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
244
+ crop_size = crop_size if crop_size is not None else self.crop_size
245
+ crop_size = get_size_dict(crop_size)
246
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
247
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
248
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
249
+ image_mean = image_mean if image_mean is not None else self.image_mean
250
+ image_std = image_std if image_std is not None else self.image_std
251
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
252
+
253
+ images = make_list_of_images(images)
254
+
255
+ if not valid_images(images):
256
+ raise ValueError(
257
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
258
+ "torch.Tensor, tf.Tensor or jax.ndarray."
259
+ )
260
+ validate_preprocess_arguments(
261
+ do_rescale=do_rescale,
262
+ rescale_factor=rescale_factor,
263
+ do_normalize=do_normalize,
264
+ image_mean=image_mean,
265
+ image_std=image_std,
266
+ do_center_crop=do_center_crop,
267
+ crop_size=crop_size,
268
+ do_resize=do_resize,
269
+ size=size,
270
+ resample=resample,
271
+ )
272
+ if do_convert_rgb:
273
+ images = [convert_to_rgb(image) for image in images]
274
+
275
+ # All transformations expect numpy arrays.
276
+ images = [to_numpy_array(image) for image in images]
277
+
278
+ if do_rescale and is_scaled_image(images[0]):
279
+ logger.warning_once(
280
+ "It looks like you are trying to rescale already rescaled images. If the input"
281
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
282
+ )
283
+
284
+ if input_data_format is None:
285
+ # We assume that all images have the same channel dimension format.
286
+ input_data_format = infer_channel_dimension_format(images[0])
287
+
288
+ all_images = []
289
+ for image in images:
290
+ if do_resize:
291
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
292
+
293
+ if do_center_crop:
294
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
295
+
296
+ if do_rescale:
297
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
298
+
299
+ if do_normalize:
300
+ image = self.normalize(
301
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
302
+ )
303
+
304
+ all_images.append(image)
305
+ images = [
306
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
307
+ for image in all_images
308
+ ]
309
+
310
+ data = {"pixel_values": images}
311
+ return BatchFeature(data=data, tensor_type=return_tensors)
312
+
313
+
314
+ __all__ = ["ChineseCLIPImageProcessor"]
docs/transformers/build/lib/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py ADDED
@@ -0,0 +1,40 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for Chinese-CLIP."""
16
+
17
+ from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
18
+ from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
19
+ from ...utils import add_start_docstrings
20
+
21
+
22
+ @add_start_docstrings(
23
+ "Constructs a fast ChineseCLIP image processor.",
24
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
25
+ )
26
+ class ChineseCLIPImageProcessorFast(BaseImageProcessorFast):
27
+ resample = PILImageResampling.BICUBIC
28
+ image_mean = OPENAI_CLIP_MEAN
29
+ image_std = OPENAI_CLIP_STD
30
+ size = {"shortest_edge": 224}
31
+ default_to_square = False
32
+ crop_size = {"height": 224, "width": 224}
33
+ do_resize = True
34
+ do_center_crop = True
35
+ do_rescale = True
36
+ do_normalize = True
37
+ do_convert_rgb = True
38
+
39
+
40
+ __all__ = ["ChineseCLIPImageProcessorFast"]
docs/transformers/build/lib/transformers/models/chinese_clip/modeling_chinese_clip.py ADDED
@@ -0,0 +1,1630 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Chinese-CLIP model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPooling,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ )
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ torch_int,
42
+ )
43
+ from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CHECKPOINT_FOR_DOC = "OFA-Sys/chinese-clip-vit-base-patch16"
49
+ _CONFIG_FOR_DOC = "ChineseCLIPConfig"
50
+
51
+
52
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
53
+ # Copied from transformers.models.clip.modeling_clip.contrastive_loss
54
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
55
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
56
+
57
+
58
+ def chinese_clip_loss(similarity: torch.Tensor) -> torch.Tensor:
59
+ caption_loss = contrastive_loss(similarity)
60
+ image_loss = contrastive_loss(similarity.t())
61
+ return (caption_loss + image_loss) / 2.0
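A toy check of the symmetric loss above, assuming a batch of four matched image/text pairs whose similarity matrix carries the positives on its diagonal:

import torch
from transformers.models.chinese_clip.modeling_chinese_clip import chinese_clip_loss

similarity = torch.randn(4, 4)        # logits for 4 image/text pairs
loss = chinese_clip_loss(similarity)  # mean of the caption and image cross-entropy terms
print(loss.item())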
62
+
63
+
64
+ @dataclass
65
+ class ChineseCLIPOutput(ModelOutput):
66
+ """
67
+ Args:
68
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
69
+ Contrastive loss for image-text similarity.
70
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
71
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
72
+ similarity scores.
73
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
74
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
75
+ similarity scores.
76
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
77
+ The text embeddings obtained by applying the projection layer to the pooled output of
78
+ [`ChineseCLIPTextModel`].
79
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
80
+ The image embeddings obtained by applying the projection layer to the pooled output of
81
+ [`ChineseCLIPVisionModel`].
82
+ text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
83
+ The output of the [`ChineseCLIPTextModel`].
84
+ vision_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
85
+ The output of the [`ChineseCLIPVisionModel`].
86
+ """
87
+
88
+ loss: Optional[torch.FloatTensor] = None
89
+ logits_per_image: Optional[torch.FloatTensor] = None
90
+ logits_per_text: Optional[torch.FloatTensor] = None
91
+ text_embeds: Optional[torch.FloatTensor] = None
92
+ image_embeds: Optional[torch.FloatTensor] = None
93
+ text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
94
+ vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
95
+
96
+ def to_tuple(self) -> Tuple[Any]:
97
+ return tuple(
98
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
99
+ for k in self.keys()
100
+ )
101
+
102
+
103
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
104
+ class ChineseCLIPTextEmbeddings(nn.Module):
105
+ """Construct the embeddings from word, position and token_type embeddings."""
106
+
107
+ def __init__(self, config):
108
+ super().__init__()
109
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
110
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
111
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
112
+
113
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
114
+ # any TensorFlow checkpoint file
115
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
116
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
117
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
118
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
119
+ self.register_buffer(
120
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
121
+ )
122
+ self.register_buffer(
123
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
124
+ )
125
+
126
+ def forward(
127
+ self,
128
+ input_ids: Optional[torch.LongTensor] = None,
129
+ token_type_ids: Optional[torch.LongTensor] = None,
130
+ position_ids: Optional[torch.LongTensor] = None,
131
+ inputs_embeds: Optional[torch.FloatTensor] = None,
132
+ past_key_values_length: int = 0,
133
+ ) -> torch.Tensor:
134
+ if input_ids is not None:
135
+ input_shape = input_ids.size()
136
+ else:
137
+ input_shape = inputs_embeds.size()[:-1]
138
+
139
+ seq_length = input_shape[1]
140
+
141
+ if position_ids is None:
142
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
143
+
144
+ # Setting the token_type_ids to the registered buffer in the constructor, where it is all zeros. This usually occurs
145
+ # when it's auto-generated; the registered buffer helps users trace the model without passing token_type_ids and solves
146
+ # issue #5664
147
+ if token_type_ids is None:
148
+ if hasattr(self, "token_type_ids"):
149
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
150
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
151
+ token_type_ids = buffered_token_type_ids_expanded
152
+ else:
153
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
154
+
155
+ if inputs_embeds is None:
156
+ inputs_embeds = self.word_embeddings(input_ids)
157
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
158
+
159
+ embeddings = inputs_embeds + token_type_embeddings
160
+ if self.position_embedding_type == "absolute":
161
+ position_embeddings = self.position_embeddings(position_ids)
162
+ embeddings += position_embeddings
163
+ embeddings = self.LayerNorm(embeddings)
164
+ embeddings = self.dropout(embeddings)
165
+ return embeddings
166
+
167
+
168
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->ChineseCLIP
169
+ class ChineseCLIPVisionEmbeddings(nn.Module):
170
+ def __init__(self, config: ChineseCLIPVisionConfig):
171
+ super().__init__()
172
+ self.config = config
173
+ self.embed_dim = config.hidden_size
174
+ self.image_size = config.image_size
175
+ self.patch_size = config.patch_size
176
+
177
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
178
+
179
+ self.patch_embedding = nn.Conv2d(
180
+ in_channels=config.num_channels,
181
+ out_channels=self.embed_dim,
182
+ kernel_size=self.patch_size,
183
+ stride=self.patch_size,
184
+ bias=False,
185
+ )
186
+
187
+ self.num_patches = (self.image_size // self.patch_size) ** 2
188
+ self.num_positions = self.num_patches + 1
189
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
190
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
191
+
192
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
193
+ """
194
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher-resolution
195
+ images. This method is also adapted to support torch.jit tracing.
196
+
197
+ Adapted from:
198
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
199
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
200
+ """
201
+
202
+ num_patches = embeddings.shape[1] - 1
203
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
204
+ num_positions = position_embedding.shape[1] - 1
205
+
206
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
207
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
208
+ return self.position_embedding(self.position_ids)
209
+
210
+ class_pos_embed = position_embedding[:, :1]
211
+ patch_pos_embed = position_embedding[:, 1:]
212
+
213
+ dim = embeddings.shape[-1]
214
+
215
+ new_height = height // self.patch_size
216
+ new_width = width // self.patch_size
217
+
218
+ sqrt_num_positions = torch_int(num_positions**0.5)
219
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
220
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
221
+
222
+ patch_pos_embed = nn.functional.interpolate(
223
+ patch_pos_embed,
224
+ size=(new_height, new_width),
225
+ mode="bicubic",
226
+ align_corners=False,
227
+ )
228
+
229
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
230
+
231
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
232
+
233
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
234
+ batch_size, _, height, width = pixel_values.shape
235
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
236
+ raise ValueError(
237
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
238
+ )
239
+ target_dtype = self.patch_embedding.weight.dtype
240
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
241
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
242
+
243
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
244
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
245
+ if interpolate_pos_encoding:
246
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
247
+ else:
248
+ embeddings = embeddings + self.position_embedding(self.position_ids)
249
+ return embeddings
250
+
251
+
252
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
253
+ class ChineseCLIPTextSelfAttention(nn.Module):
254
+ def __init__(self, config, position_embedding_type=None):
255
+ super().__init__()
256
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
257
+ raise ValueError(
258
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
259
+ f"heads ({config.num_attention_heads})"
260
+ )
261
+
262
+ self.num_attention_heads = config.num_attention_heads
263
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
264
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
265
+
266
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
267
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
268
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
269
+
270
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
271
+ self.position_embedding_type = position_embedding_type or getattr(
272
+ config, "position_embedding_type", "absolute"
273
+ )
274
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
275
+ self.max_position_embeddings = config.max_position_embeddings
276
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
277
+
278
+ self.is_decoder = config.is_decoder
279
+
280
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
281
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
282
+ x = x.view(new_x_shape)
283
+ return x.permute(0, 2, 1, 3)
284
+
285
+ def forward(
286
+ self,
287
+ hidden_states: torch.Tensor,
288
+ attention_mask: Optional[torch.FloatTensor] = None,
289
+ head_mask: Optional[torch.FloatTensor] = None,
290
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
291
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
292
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
293
+ output_attentions: Optional[bool] = False,
294
+ ) -> Tuple[torch.Tensor]:
295
+ mixed_query_layer = self.query(hidden_states)
296
+
297
+ # If this is instantiated as a cross-attention module, the keys
298
+ # and values come from an encoder; the attention mask needs to be
299
+ # such that the encoder's padding tokens are not attended to.
300
+ is_cross_attention = encoder_hidden_states is not None
301
+
302
+ if is_cross_attention and past_key_value is not None:
303
+ # reuse k,v, cross_attentions
304
+ key_layer = past_key_value[0]
305
+ value_layer = past_key_value[1]
306
+ attention_mask = encoder_attention_mask
307
+ elif is_cross_attention:
308
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
309
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
310
+ attention_mask = encoder_attention_mask
311
+ elif past_key_value is not None:
312
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
313
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
314
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
315
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
316
+ else:
317
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
318
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
319
+
320
+ query_layer = self.transpose_for_scores(mixed_query_layer)
321
+
322
+ use_cache = past_key_value is not None
323
+ if self.is_decoder:
324
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
325
+ # Further calls to cross_attention layer can then reuse all cross-attention
326
+ # key/value_states (first "if" case)
327
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
328
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
329
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
330
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
331
+ past_key_value = (key_layer, value_layer)
332
+
333
+ # Take the dot product between "query" and "key" to get the raw attention scores.
334
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
335
+
336
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
337
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
338
+ if use_cache:
339
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
340
+ -1, 1
341
+ )
342
+ else:
343
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
344
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
345
+ distance = position_ids_l - position_ids_r
346
+
347
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
348
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
349
+
350
+ if self.position_embedding_type == "relative_key":
351
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
352
+ attention_scores = attention_scores + relative_position_scores
353
+ elif self.position_embedding_type == "relative_key_query":
354
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
355
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
356
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
357
+
358
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
359
+ if attention_mask is not None:
360
+ # Apply the attention mask (precomputed for all layers in the ChineseCLIPTextModel forward() function)
361
+ attention_scores = attention_scores + attention_mask
362
+
363
+ # Normalize the attention scores to probabilities.
364
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
365
+
366
+ # This is actually dropping out entire tokens to attend to, which might
367
+ # seem a bit unusual, but is taken from the original Transformer paper.
368
+ attention_probs = self.dropout(attention_probs)
369
+
370
+ # Mask heads if we want to
371
+ if head_mask is not None:
372
+ attention_probs = attention_probs * head_mask
373
+
374
+ context_layer = torch.matmul(attention_probs, value_layer)
375
+
376
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
377
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
378
+ context_layer = context_layer.view(new_context_layer_shape)
379
+
380
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
381
+
382
+ if self.is_decoder:
383
+ outputs = outputs + (past_key_value,)
384
+ return outputs
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->ChineseCLIPText
388
+ class ChineseCLIPTextSelfOutput(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
392
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
393
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
394
+
395
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
396
+ hidden_states = self.dense(hidden_states)
397
+ hidden_states = self.dropout(hidden_states)
398
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
399
+ return hidden_states
400
+
401
+
402
+ CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
403
+ "eager": ChineseCLIPTextSelfAttention,
404
+ }
405
+
406
+
407
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
408
+ class ChineseCLIPTextAttention(nn.Module):
409
+ def __init__(self, config, position_embedding_type=None):
410
+ super().__init__()
411
+ self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
412
+ config, position_embedding_type=position_embedding_type
413
+ )
414
+ self.output = ChineseCLIPTextSelfOutput(config)
415
+ self.pruned_heads = set()
416
+
417
+ def prune_heads(self, heads):
418
+ if len(heads) == 0:
419
+ return
420
+ heads, index = find_pruneable_heads_and_indices(
421
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
422
+ )
423
+
424
+ # Prune linear layers
425
+ self.self.query = prune_linear_layer(self.self.query, index)
426
+ self.self.key = prune_linear_layer(self.self.key, index)
427
+ self.self.value = prune_linear_layer(self.self.value, index)
428
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
429
+
430
+ # Update hyper params and store pruned heads
431
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
432
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
433
+ self.pruned_heads = self.pruned_heads.union(heads)
434
+
435
+ def forward(
436
+ self,
437
+ hidden_states: torch.Tensor,
438
+ attention_mask: Optional[torch.FloatTensor] = None,
439
+ head_mask: Optional[torch.FloatTensor] = None,
440
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
441
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
442
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
443
+ output_attentions: Optional[bool] = False,
444
+ ) -> Tuple[torch.Tensor]:
445
+ self_outputs = self.self(
446
+ hidden_states,
447
+ attention_mask,
448
+ head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ )
454
+ attention_output = self.output(self_outputs[0], hidden_states)
455
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
456
+ return outputs
457
+
458
+
459
+ class ChineseCLIPVisionAttention(nn.Module):
460
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
461
+
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.config = config
465
+ self.embed_dim = config.hidden_size
466
+ self.num_heads = config.num_attention_heads
467
+ self.head_dim = self.embed_dim // self.num_heads
468
+ if self.head_dim * self.num_heads != self.embed_dim:
469
+ raise ValueError(
470
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
471
+ f" {self.num_heads})."
472
+ )
473
+ self.scale = self.head_dim**-0.5
474
+ self.dropout = config.attention_dropout
475
+
476
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
477
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
478
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
479
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
480
+
481
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
482
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states: torch.Tensor,
487
+ output_attentions: Optional[bool] = False,
488
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
489
+ """Input shape: Batch x Time x Channel"""
490
+
491
+ bsz, tgt_len, embed_dim = hidden_states.size()
492
+
493
+ # get query proj
494
+ query_states = self.q_proj(hidden_states) * self.scale
495
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
496
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
497
+
498
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
499
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
500
+ key_states = key_states.view(*proj_shape)
501
+ value_states = value_states.view(*proj_shape)
502
+
503
+ src_len = key_states.size(1)
504
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
505
+
506
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
507
+ raise ValueError(
508
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
509
+ f" {attn_weights.size()}"
510
+ )
511
+
512
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
513
+
514
+ if output_attentions:
515
+ # this operation is a bit awkward, but it's required to
516
+ # make sure that attn_weights keeps its gradient.
517
+ # In order to do so, attn_weights has to be reshaped
518
+ # twice and then reused in the following
519
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
520
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
521
+ else:
522
+ attn_weights_reshaped = None
523
+
524
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
525
+
526
+ attn_output = torch.bmm(attn_probs, value_states)
527
+
528
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
529
+ raise ValueError(
530
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
531
+ f" {attn_output.size()}"
532
+ )
533
+
534
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
535
+ attn_output = attn_output.transpose(1, 2)
536
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
537
+
538
+ attn_output = self.out_proj(attn_output)
539
+
540
+ return attn_output, attn_weights_reshaped
541
+
542
+
543
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
544
+ class ChineseCLIPTextIntermediate(nn.Module):
545
+ def __init__(self, config):
546
+ super().__init__()
547
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
548
+ if isinstance(config.hidden_act, str):
549
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
550
+ else:
551
+ self.intermediate_act_fn = config.hidden_act
552
+
553
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
554
+ hidden_states = self.dense(hidden_states)
555
+ hidden_states = self.intermediate_act_fn(hidden_states)
556
+ return hidden_states
557
+
558
+
559
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->ChineseCLIPText
560
+ class ChineseCLIPTextOutput(nn.Module):
561
+ def __init__(self, config):
562
+ super().__init__()
563
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
564
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
565
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
566
+
567
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
568
+ hidden_states = self.dense(hidden_states)
569
+ hidden_states = self.dropout(hidden_states)
570
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
571
+ return hidden_states
572
+
573
+
574
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->ChineseCLIPVision
575
+ class ChineseCLIPVisionMLP(nn.Module):
576
+ def __init__(self, config):
577
+ super().__init__()
578
+ self.config = config
579
+ self.activation_fn = ACT2FN[config.hidden_act]
580
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
581
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
582
+
583
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
584
+ hidden_states = self.fc1(hidden_states)
585
+ hidden_states = self.activation_fn(hidden_states)
586
+ hidden_states = self.fc2(hidden_states)
587
+ return hidden_states
588
+
589
+
590
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
591
+ class ChineseCLIPTextLayer(nn.Module):
592
+ def __init__(self, config):
593
+ super().__init__()
594
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
595
+ self.seq_len_dim = 1
596
+ self.attention = ChineseCLIPTextAttention(config)
597
+ self.is_decoder = config.is_decoder
598
+ self.add_cross_attention = config.add_cross_attention
599
+ if self.add_cross_attention:
600
+ if not self.is_decoder:
601
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
602
+ self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
603
+ self.intermediate = ChineseCLIPTextIntermediate(config)
604
+ self.output = ChineseCLIPTextOutput(config)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.Tensor,
609
+ attention_mask: Optional[torch.FloatTensor] = None,
610
+ head_mask: Optional[torch.FloatTensor] = None,
611
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
612
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
613
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
614
+ output_attentions: Optional[bool] = False,
615
+ ) -> Tuple[torch.Tensor]:
616
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
617
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
618
+ self_attention_outputs = self.attention(
619
+ hidden_states,
620
+ attention_mask,
621
+ head_mask,
622
+ output_attentions=output_attentions,
623
+ past_key_value=self_attn_past_key_value,
624
+ )
625
+ attention_output = self_attention_outputs[0]
626
+
627
+ # if decoder, the last output is tuple of self-attn cache
628
+ if self.is_decoder:
629
+ outputs = self_attention_outputs[1:-1]
630
+ present_key_value = self_attention_outputs[-1]
631
+ else:
632
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
633
+
634
+ cross_attn_present_key_value = None
635
+ if self.is_decoder and encoder_hidden_states is not None:
636
+ if not hasattr(self, "crossattention"):
637
+ raise ValueError(
638
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
639
+ " by setting `config.add_cross_attention=True`"
640
+ )
641
+
642
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
643
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
644
+ cross_attention_outputs = self.crossattention(
645
+ attention_output,
646
+ attention_mask,
647
+ head_mask,
648
+ encoder_hidden_states,
649
+ encoder_attention_mask,
650
+ cross_attn_past_key_value,
651
+ output_attentions,
652
+ )
653
+ attention_output = cross_attention_outputs[0]
654
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
655
+
656
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
657
+ cross_attn_present_key_value = cross_attention_outputs[-1]
658
+ present_key_value = present_key_value + cross_attn_present_key_value
659
+
660
+ layer_output = apply_chunking_to_forward(
661
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
662
+ )
663
+ outputs = (layer_output,) + outputs
664
+
665
+ # if decoder, return the attn key/values as the last output
666
+ if self.is_decoder:
667
+ outputs = outputs + (present_key_value,)
668
+
669
+ return outputs
670
+
671
+ def feed_forward_chunk(self, attention_output):
672
+ intermediate_output = self.intermediate(attention_output)
673
+ layer_output = self.output(intermediate_output, attention_output)
674
+ return layer_output
675
+
676
+
677
+ class ChineseCLIPVisionLayer(nn.Module):
678
+ def __init__(self, config: ChineseCLIPConfig):
679
+ super().__init__()
680
+ self.embed_dim = config.hidden_size
681
+ self.self_attn = ChineseCLIPVisionAttention(config)
682
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
683
+ self.mlp = ChineseCLIPVisionMLP(config)
684
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
685
+
686
+ def forward(
687
+ self,
688
+ hidden_states: torch.Tensor,
689
+ output_attentions: Optional[bool] = False,
690
+ ) -> Tuple[torch.FloatTensor]:
691
+ """
692
+ Args:
693
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
694
+ output_attentions (`bool`, *optional*):
695
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
696
+ returned tensors for more detail.
697
+ """
698
+ residual = hidden_states
699
+
700
+ hidden_states = self.layer_norm1(hidden_states)
701
+ hidden_states, attn_weights = self.self_attn(
702
+ hidden_states=hidden_states,
703
+ output_attentions=output_attentions,
704
+ )
705
+ hidden_states = residual + hidden_states
706
+
707
+ residual = hidden_states
708
+ hidden_states = self.layer_norm2(hidden_states)
709
+ hidden_states = self.mlp(hidden_states)
710
+ hidden_states = residual + hidden_states
711
+
712
+ outputs = (hidden_states,)
713
+
714
+ if output_attentions:
715
+ outputs += (attn_weights,)
716
+
717
+ return outputs
718
+
719
+
720
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->ChineseCLIPText
721
+ class ChineseCLIPTextPooler(nn.Module):
722
+ def __init__(self, config):
723
+ super().__init__()
724
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
725
+ self.activation = nn.Tanh()
726
+
727
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
728
+ # We "pool" the model by simply taking the hidden state corresponding
729
+ # to the first token.
730
+ first_token_tensor = hidden_states[:, 0]
731
+ pooled_output = self.dense(first_token_tensor)
732
+ pooled_output = self.activation(pooled_output)
733
+ return pooled_output
734
+
735
+
736
+ class ChineseCLIPPreTrainedModel(PreTrainedModel):
737
+ """
738
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
739
+ models.
740
+ """
741
+
742
+ config_class = ChineseCLIPConfig
743
+ base_model_prefix = "chinese_clip"
744
+ supports_gradient_checkpointing = True
745
+
746
+ def _init_weights(self, module):
747
+ """Initialize the weights"""
748
+ factor = self.config.initializer_factor
749
+ if isinstance(module, ChineseCLIPVisionEmbeddings):
750
+ factor = self.config.initializer_factor
751
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
752
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
753
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
754
+ elif isinstance(module, ChineseCLIPTextEmbeddings):
755
+ nn.init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range)
756
+ nn.init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range)
757
+ nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range)
758
+ for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]:
759
+ if embedding.padding_idx is not None:
760
+ embedding.weight.data[embedding.padding_idx].zero_()
761
+ elif isinstance(module, ChineseCLIPVisionAttention):
762
+ factor = self.config.initializer_factor
763
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
764
+ out_proj_std = (module.embed_dim**-0.5) * factor
765
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
766
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
767
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
768
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
769
+ elif isinstance(module, ChineseCLIPVisionMLP):
770
+ factor = self.config.initializer_factor
771
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
772
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
773
+ nn.init.normal_(module.fc1.weight, std=fc_std)
774
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
775
+ elif isinstance(module, ChineseCLIPModel):
776
+ nn.init.normal_(
777
+ module.text_projection.weight,
778
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
779
+ )
780
+ nn.init.normal_(
781
+ module.visual_projection.weight,
782
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
783
+ )
784
+
785
+ if isinstance(module, nn.LayerNorm):
786
+ module.bias.data.zero_()
787
+ module.weight.data.fill_(1.0)
788
+ if isinstance(module, nn.Linear):
789
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
790
+ if module.bias is not None:
791
+ module.bias.data.zero_()
792
+
793
+
794
+ CHINESE_CLIP_START_DOCSTRING = r"""
795
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
796
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
797
+ behavior.
798
+
799
+ Parameters:
800
+ config ([`ChineseCLIPConfig`]): Model configuration class with all the parameters of the model.
801
+ Initializing with a config file does not load the weights associated with the model, only the
802
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
803
+ """
804
+
805
+ CHINESE_CLIP_TEXT_INPUTS_DOCSTRING = r"""
806
+ Args:
807
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
808
+ Indices of input sequence tokens in the vocabulary.
809
+
810
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
811
+ [`PreTrainedTokenizer.__call__`] for details.
812
+
813
+ [What are input IDs?](../glossary#input-ids)
814
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
815
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
816
+
817
+ - 1 for tokens that are **not masked**,
818
+ - 0 for tokens that are **masked**.
819
+
820
+ [What are attention masks?](../glossary#attention-mask)
821
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
822
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
823
+ 1]`:
824
+
825
+ - 0 corresponds to a *sentence A* token,
826
+ - 1 corresponds to a *sentence B* token.
827
+
828
+ [What are token type IDs?](../glossary#token-type-ids)
829
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
830
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
831
+ config.max_position_embeddings - 1]`.
832
+
833
+ [What are position IDs?](../glossary#position-ids)
834
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
835
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
836
+
837
+ - 1 indicates the head is **not masked**,
838
+ - 0 indicates the head is **masked**.
839
+
840
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
841
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
842
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
843
+ model's internal embedding lookup matrix.
844
+ output_attentions (`bool`, *optional*):
845
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
846
+ tensors for more detail.
847
+ output_hidden_states (`bool`, *optional*):
848
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
849
+ more detail.
850
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
851
+ Whether to interpolate the pre-trained position encodings.
852
+ return_dict (`bool`, *optional*):
853
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
854
+ """
855
+
856
+ CHINESE_CLIP_VISION_INPUTS_DOCSTRING = r"""
857
+ Args:
858
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
859
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
860
+ [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details.
861
+ output_attentions (`bool`, *optional*):
862
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
863
+ tensors for more detail.
864
+ output_hidden_states (`bool`, *optional*):
865
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
866
+ more detail.
867
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
868
+ Whether to interpolate the pre-trained position encodings.
869
+ return_dict (`bool`, *optional*):
870
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
871
+ """
872
+
873
+ CHINESE_CLIP_INPUTS_DOCSTRING = r"""
874
+ Args:
875
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
876
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
877
+ it.
878
+
879
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
880
+ [`PreTrainedTokenizer.__call__`] for details.
881
+
882
+ [What are input IDs?](../glossary#input-ids)
883
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
884
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
885
+
886
+ - 1 for tokens that are **not masked**,
887
+ - 0 for tokens that are **masked**.
888
+
889
+ [What are attention masks?](../glossary#attention-mask)
890
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
891
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
892
+ 1]`:
893
+
894
+ - 0 corresponds to a *sentence A* token,
895
+ - 1 corresponds to a *sentence B* token.
896
+
897
+ [What are token type IDs?](../glossary#token-type-ids)
898
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
899
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
900
+ config.max_position_embeddings - 1]`.
901
+
902
+ [What are position IDs?](../glossary#position-ids)
903
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
904
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
905
+ [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details.
906
+ return_loss (`bool`, *optional*):
907
+ Whether or not to return the contrastive loss.
908
+ output_attentions (`bool`, *optional*):
909
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
910
+ tensors for more detail.
911
+ output_hidden_states (`bool`, *optional*):
912
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
913
+ more detail.
914
+ return_dict (`bool`, *optional*):
915
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
916
+ """
917
+
918
+
919
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
920
+ class ChineseCLIPTextEncoder(nn.Module):
921
+ def __init__(self, config):
922
+ super().__init__()
923
+ self.config = config
924
+ self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
925
+ self.gradient_checkpointing = False
926
+
927
+ def forward(
928
+ self,
929
+ hidden_states: torch.Tensor,
930
+ attention_mask: Optional[torch.FloatTensor] = None,
931
+ head_mask: Optional[torch.FloatTensor] = None,
932
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
933
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
934
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
935
+ use_cache: Optional[bool] = None,
936
+ output_attentions: Optional[bool] = False,
937
+ output_hidden_states: Optional[bool] = False,
938
+ return_dict: Optional[bool] = True,
939
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
940
+ all_hidden_states = () if output_hidden_states else None
941
+ all_self_attentions = () if output_attentions else None
942
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
943
+
944
+ if self.gradient_checkpointing and self.training:
945
+ if use_cache:
946
+ logger.warning_once(
947
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
948
+ )
949
+ use_cache = False
950
+
951
+ next_decoder_cache = () if use_cache else None
952
+ for i, layer_module in enumerate(self.layer):
953
+ if output_hidden_states:
954
+ all_hidden_states = all_hidden_states + (hidden_states,)
955
+
956
+ layer_head_mask = head_mask[i] if head_mask is not None else None
957
+ past_key_value = past_key_values[i] if past_key_values is not None else None
958
+
959
+ if self.gradient_checkpointing and self.training:
960
+ layer_outputs = self._gradient_checkpointing_func(
961
+ layer_module.__call__,
962
+ hidden_states,
963
+ attention_mask,
964
+ layer_head_mask,
965
+ encoder_hidden_states,
966
+ encoder_attention_mask,
967
+ past_key_value,
968
+ output_attentions,
969
+ )
970
+ else:
971
+ layer_outputs = layer_module(
972
+ hidden_states,
973
+ attention_mask,
974
+ layer_head_mask,
975
+ encoder_hidden_states,
976
+ encoder_attention_mask,
977
+ past_key_value,
978
+ output_attentions,
979
+ )
980
+
981
+ hidden_states = layer_outputs[0]
982
+ if use_cache:
983
+ next_decoder_cache += (layer_outputs[-1],)
984
+ if output_attentions:
985
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
986
+ if self.config.add_cross_attention:
987
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
988
+
989
+ if output_hidden_states:
990
+ all_hidden_states = all_hidden_states + (hidden_states,)
991
+
992
+ if not return_dict:
993
+ return tuple(
994
+ v
995
+ for v in [
996
+ hidden_states,
997
+ next_decoder_cache,
998
+ all_hidden_states,
999
+ all_self_attentions,
1000
+ all_cross_attentions,
1001
+ ]
1002
+ if v is not None
1003
+ )
1004
+ return BaseModelOutputWithPastAndCrossAttentions(
1005
+ last_hidden_state=hidden_states,
1006
+ past_key_values=next_decoder_cache,
1007
+ hidden_states=all_hidden_states,
1008
+ attentions=all_self_attentions,
1009
+ cross_attentions=all_cross_attentions,
1010
+ )
1011
+
1012
+
1013
+ class ChineseCLIPVisionEncoder(nn.Module):
1014
+ """
1015
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
1016
+ [`ChineseCLIPVisionEncoderLayer`].
1017
+
1018
+ Args:
1019
+ config: ChineseCLIPConfig
1020
+ """
1021
+
1022
+ def __init__(self, config: ChineseCLIPConfig):
1023
+ super().__init__()
1024
+ self.config = config
1025
+ self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
1026
+ self.gradient_checkpointing = False
1027
+
1028
+ def forward(
1029
+ self,
1030
+ inputs_embeds,
1031
+ output_attentions: Optional[bool] = None,
1032
+ output_hidden_states: Optional[bool] = None,
1033
+ return_dict: Optional[bool] = None,
1034
+ ) -> Union[Tuple, BaseModelOutput]:
1035
+ r"""
1036
+ Args:
1037
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1038
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1039
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1040
+ than the model's internal embedding lookup matrix.
1041
+ output_attentions (`bool`, *optional*):
1042
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1043
+ returned tensors for more detail.
1044
+ output_hidden_states (`bool`, *optional*):
1045
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1046
+ for more detail.
1047
+ return_dict (`bool`, *optional*):
1048
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1049
+ """
1050
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1051
+ output_hidden_states = (
1052
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1053
+ )
1054
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1055
+
1056
+ encoder_states = () if output_hidden_states else None
1057
+ all_attentions = () if output_attentions else None
1058
+
1059
+ hidden_states = inputs_embeds
1060
+ for idx, encoder_layer in enumerate(self.layers):
1061
+ if output_hidden_states:
1062
+ encoder_states = encoder_states + (hidden_states,)
1063
+ if self.gradient_checkpointing and self.training:
1064
+ layer_outputs = self._gradient_checkpointing_func(
1065
+ encoder_layer.__call__,
1066
+ hidden_states,
1067
+ output_attentions,
1068
+ )
1069
+ else:
1070
+ layer_outputs = encoder_layer(
1071
+ hidden_states,
1072
+ output_attentions=output_attentions,
1073
+ )
1074
+
1075
+ hidden_states = layer_outputs[0]
1076
+
1077
+ if output_attentions:
1078
+ all_attentions = all_attentions + (layer_outputs[1],)
1079
+
1080
+ if output_hidden_states:
1081
+ encoder_states = encoder_states + (hidden_states,)
1082
+
1083
+ if not return_dict:
1084
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1085
+ return BaseModelOutput(
1086
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1087
+ )
1088
+
1089
+
1090
+ class ChineseCLIPVisionTransformer(nn.Module):
1091
+ def __init__(self, config: ChineseCLIPVisionConfig):
1092
+ super().__init__()
1093
+ self.config = config
1094
+ embed_dim = config.hidden_size
1095
+
1096
+ self.embeddings = ChineseCLIPVisionEmbeddings(config)
1097
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1098
+ self.encoder = ChineseCLIPVisionEncoder(config)
1099
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1100
+
1101
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1102
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig)
1103
+ def forward(
1104
+ self,
1105
+ pixel_values: Optional[torch.FloatTensor] = None,
1106
+ output_attentions: Optional[bool] = None,
1107
+ output_hidden_states: Optional[bool] = None,
1108
+ interpolate_pos_encoding: bool = False,
1109
+ return_dict: Optional[bool] = None,
1110
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1111
+ r"""
1112
+ Returns:
1113
+ """
1114
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1115
+ output_hidden_states = (
1116
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1117
+ )
1118
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1119
+
1120
+ if pixel_values is None:
1121
+ raise ValueError("You have to specify pixel_values")
1122
+
1123
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1124
+ hidden_states = self.pre_layrnorm(hidden_states)
1125
+
1126
+ encoder_outputs = self.encoder(
1127
+ inputs_embeds=hidden_states,
1128
+ output_attentions=output_attentions,
1129
+ output_hidden_states=output_hidden_states,
1130
+ return_dict=return_dict,
1131
+ )
1132
+
1133
+ last_hidden_state = encoder_outputs[0]
1134
+ pooled_output = last_hidden_state[:, 0, :]
1135
+ pooled_output = self.post_layernorm(pooled_output)
1136
+
1137
+ if not return_dict:
1138
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1139
+
1140
+ return BaseModelOutputWithPooling(
1141
+ last_hidden_state=last_hidden_state,
1142
+ pooler_output=pooled_output,
1143
+ hidden_states=encoder_outputs.hidden_states,
1144
+ attentions=encoder_outputs.attentions,
1145
+ )
1146
+
1147
+
1148
+ @add_start_docstrings(
1149
+ "The text model from CHINESE_CLIP without any head or projection on top.",
1150
+ CHINESE_CLIP_START_DOCSTRING,
1151
+ )
1152
+ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
1153
+ """
1154
+
1155
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
1156
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
1157
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
1158
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
1159
+
1160
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
1161
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
1162
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
1163
+ """
1164
+
1165
+ config_class = ChineseCLIPTextConfig
1166
+ _no_split_modules = ["ChineseCLIPTextEmbeddings"]
1167
+
1168
+ def __init__(self, config, add_pooling_layer=True):
1169
+ super().__init__(config)
1170
+ self.config = config
1171
+
1172
+ self.embeddings = ChineseCLIPTextEmbeddings(config)
1173
+ self.encoder = ChineseCLIPTextEncoder(config)
1174
+
1175
+ self.pooler = ChineseCLIPTextPooler(config) if add_pooling_layer else None
1176
+
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+ def get_input_embeddings(self):
1181
+ return self.embeddings.word_embeddings
1182
+
1183
+ def set_input_embeddings(self, value):
1184
+ self.embeddings.word_embeddings = value
1185
+
1186
+ def _prune_heads(self, heads_to_prune):
1187
+ """
1188
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1189
+ class PreTrainedModel
1190
+ """
1191
+ for layer, heads in heads_to_prune.items():
1192
+ self.encoder.layer[layer].attention.prune_heads(heads)
1193
+
1194
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1195
+ @add_code_sample_docstrings(
1196
+ checkpoint=_CHECKPOINT_FOR_DOC,
1197
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
1198
+ config_class=_CONFIG_FOR_DOC,
1199
+ )
1200
+ def forward(
1201
+ self,
1202
+ input_ids: Optional[torch.Tensor] = None,
1203
+ attention_mask: Optional[torch.Tensor] = None,
1204
+ token_type_ids: Optional[torch.Tensor] = None,
1205
+ position_ids: Optional[torch.Tensor] = None,
1206
+ head_mask: Optional[torch.Tensor] = None,
1207
+ inputs_embeds: Optional[torch.Tensor] = None,
1208
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1209
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1210
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1211
+ use_cache: Optional[bool] = None,
1212
+ output_attentions: Optional[bool] = None,
1213
+ output_hidden_states: Optional[bool] = None,
1214
+ return_dict: Optional[bool] = None,
1215
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1216
+ r"""
1217
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1218
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1219
+ the model is configured as a decoder.
1220
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1221
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1222
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1223
+
1224
+ - 1 for tokens that are **not masked**,
1225
+ - 0 for tokens that are **masked**.
1226
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1227
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1228
+
1229
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1230
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1231
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1232
+ use_cache (`bool`, *optional*):
1233
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1234
+ `past_key_values`).
1235
+ """
1236
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1237
+ output_hidden_states = (
1238
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1239
+ )
1240
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1241
+
1242
+ if self.config.is_decoder:
1243
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1244
+ else:
1245
+ use_cache = False
1246
+
1247
+ if input_ids is not None and inputs_embeds is not None:
1248
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1249
+ elif input_ids is not None:
1250
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1251
+ input_shape = input_ids.size()
1252
+ elif inputs_embeds is not None:
1253
+ input_shape = inputs_embeds.size()[:-1]
1254
+ else:
1255
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1256
+
1257
+ batch_size, seq_length = input_shape
1258
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1259
+
1260
+ # past_key_values_length
1261
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1262
+
1263
+ if attention_mask is None:
1264
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1265
+
1266
+ if token_type_ids is None:
1267
+ if hasattr(self.embeddings, "token_type_ids"):
1268
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1269
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1270
+ token_type_ids = buffered_token_type_ids_expanded
1271
+ else:
1272
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1273
+
1274
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1275
+ # ourselves in which case we just need to make it broadcastable to all heads.
1276
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1277
+
1278
+ # If a 2D or 3D attention mask is provided for the cross-attention
1279
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1280
+ if self.config.is_decoder and encoder_hidden_states is not None:
1281
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1282
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1283
+ if encoder_attention_mask is None:
1284
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1285
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1286
+ else:
1287
+ encoder_extended_attention_mask = None
1288
+
1289
+ # Prepare head mask if needed
1290
+ # 1.0 in head_mask indicate we keep the head
1291
+ # attention_probs has shape bsz x n_heads x N x N
1292
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1293
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1294
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1295
+
1296
+ embedding_output = self.embeddings(
1297
+ input_ids=input_ids,
1298
+ position_ids=position_ids,
1299
+ token_type_ids=token_type_ids,
1300
+ inputs_embeds=inputs_embeds,
1301
+ past_key_values_length=past_key_values_length,
1302
+ )
1303
+ encoder_outputs = self.encoder(
1304
+ embedding_output,
1305
+ attention_mask=extended_attention_mask,
1306
+ head_mask=head_mask,
1307
+ encoder_hidden_states=encoder_hidden_states,
1308
+ encoder_attention_mask=encoder_extended_attention_mask,
1309
+ past_key_values=past_key_values,
1310
+ use_cache=use_cache,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ )
1315
+ sequence_output = encoder_outputs[0]
1316
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1317
+
1318
+ if not return_dict:
1319
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1320
+
1321
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1322
+ last_hidden_state=sequence_output,
1323
+ pooler_output=pooled_output,
1324
+ past_key_values=encoder_outputs.past_key_values,
1325
+ hidden_states=encoder_outputs.hidden_states,
1326
+ attentions=encoder_outputs.attentions,
1327
+ cross_attentions=encoder_outputs.cross_attentions,
1328
+ )
1329
+
1330
+
1331
+ @add_start_docstrings(
1332
+ """The vision model from CHINESE_CLIP without any head or projection on top.""",
1333
+ CHINESE_CLIP_START_DOCSTRING,
1334
+ )
1335
+ class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
1336
+ config_class = ChineseCLIPVisionConfig
1337
+ main_input_name = "pixel_values"
1338
+ _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
1339
+
1340
+ def __init__(self, config: ChineseCLIPVisionConfig):
1341
+ super().__init__(config)
1342
+ self.vision_model = ChineseCLIPVisionTransformer(config)
1343
+ # Initialize weights and apply final processing
1344
+ self.post_init()
1345
+
1346
+ def get_input_embeddings(self) -> nn.Module:
1347
+ return self.vision_model.embeddings.patch_embedding
1348
+
1349
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1350
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig)
1351
+ def forward(
1352
+ self,
1353
+ pixel_values: Optional[torch.FloatTensor] = None,
1354
+ output_attentions: Optional[bool] = None,
1355
+ output_hidden_states: Optional[bool] = None,
1356
+ interpolate_pos_encoding: bool = False,
1357
+ return_dict: Optional[bool] = None,
1358
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1359
+ r"""
1360
+ Returns:
1361
+
1362
+ Examples:
1363
+
1364
+ ```python
1365
+ >>> from PIL import Image
1366
+ >>> import requests
1367
+ >>> from transformers import CLIPProcessor, ChineseCLIPVisionModel
1368
+
1369
+ >>> model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1370
+ >>> processor = CLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1371
+
1372
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1373
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1374
+
1375
+ >>> inputs = processor(images=image, return_tensors="pt")
1376
+
1377
+ >>> outputs = model(**inputs)
1378
+ >>> last_hidden_state = outputs.last_hidden_state
1379
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1380
+ ```"""
1381
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1382
+
1383
+ return self.vision_model(
1384
+ pixel_values=pixel_values,
1385
+ output_attentions=output_attentions,
1386
+ output_hidden_states=output_hidden_states,
1387
+ interpolate_pos_encoding=interpolate_pos_encoding,
1388
+ return_dict=return_dict,
1389
+ )
1390
+
1391
+
1392
+ @add_start_docstrings(CHINESE_CLIP_START_DOCSTRING)
1393
+ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
1394
+ config_class = ChineseCLIPConfig
1395
+
1396
+ def __init__(self, config: ChineseCLIPConfig):
1397
+ super().__init__(config)
1398
+
1399
+ if not isinstance(config.text_config, ChineseCLIPTextConfig):
1400
+ raise TypeError(
1401
+ "config.text_config is expected to be of type ChineseCLIPTextConfig but is of type"
1402
+ f" {type(config.text_config)}."
1403
+ )
1404
+
1405
+ if not isinstance(config.vision_config, ChineseCLIPVisionConfig):
1406
+ raise TypeError(
1407
+ "config.vision_config is expected to be of type ChineseCLIPVisionConfig but is of type"
1408
+ f" {type(config.vision_config)}."
1409
+ )
1410
+
1411
+ text_config = config.text_config
1412
+ vision_config = config.vision_config
1413
+
1414
+ self.projection_dim = config.projection_dim
1415
+ self.text_embed_dim = text_config.hidden_size
1416
+ self.vision_embed_dim = vision_config.hidden_size
1417
+
1418
+ self.text_model = ChineseCLIPTextModel(text_config, add_pooling_layer=False)
1419
+ self.vision_model = ChineseCLIPVisionTransformer(vision_config)
1420
+
1421
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1422
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
1423
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1424
+
1425
+ # Initialize weights and apply final processing
1426
+ self.post_init()
1427
+
1428
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_TEXT_INPUTS_DOCSTRING)
1429
+ def get_text_features(
1430
+ self,
1431
+ input_ids: Optional[torch.Tensor] = None,
1432
+ attention_mask: Optional[torch.Tensor] = None,
1433
+ token_type_ids: Optional[torch.Tensor] = None,
1434
+ position_ids: Optional[torch.Tensor] = None,
1435
+ output_attentions: Optional[bool] = None,
1436
+ output_hidden_states: Optional[bool] = None,
1437
+ return_dict: Optional[bool] = None,
1438
+ ) -> torch.FloatTensor:
1439
+ r"""
1440
+ Returns:
1441
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1442
+ applying the projection layer to the final [CLS] hidden state of Text-Transformer.
1443
+
1444
+ Examples:
1445
+
1446
+ ```python
1447
+ >>> from transformers import AutoTokenizer, ChineseCLIPModel
1448
+
1449
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1450
+ >>> tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1451
+
1452
+ >>> inputs = tokenizer(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], padding=True, return_tensors="pt")
1453
+ >>> text_features = model.get_text_features(**inputs)
1454
+ >>> text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
1455
+ ```"""
1456
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1457
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1458
+ output_hidden_states = (
1459
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1460
+ )
1461
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1462
+
1463
+ text_outputs = self.text_model(
1464
+ input_ids=input_ids,
1465
+ attention_mask=attention_mask,
1466
+ token_type_ids=token_type_ids,
1467
+ position_ids=position_ids,
1468
+ output_attentions=output_attentions,
1469
+ output_hidden_states=output_hidden_states,
1470
+ return_dict=return_dict,
1471
+ )
1472
+
1473
+ pooled_output = text_outputs[0][:, 0, :]
1474
+ text_features = self.text_projection(pooled_output)
1475
+
1476
+ return text_features
1477
+
1478
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1479
+ def get_image_features(
1480
+ self,
1481
+ pixel_values: Optional[torch.FloatTensor] = None,
1482
+ output_attentions: Optional[bool] = None,
1483
+ output_hidden_states: Optional[bool] = None,
1484
+ interpolate_pos_encoding: bool = False,
1485
+ return_dict: Optional[bool] = None,
1486
+ ) -> torch.FloatTensor:
1487
+ r"""
1488
+ Returns:
1489
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1490
+ applying the projection layer to the final [CLS] hidden state of Vision-Transformer.
1491
+
1492
+ Examples:
1493
+
1494
+ ```python
1495
+ >>> from PIL import Image
1496
+ >>> import requests
1497
+ >>> from transformers import AutoProcessor, ChineseCLIPModel
1498
+
1499
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1500
+ >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1501
+
1502
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1503
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1504
+
1505
+ >>> inputs = processor(images=image, return_tensors="pt")
1506
+
1507
+ >>> image_features = model.get_image_features(**inputs)
1508
+ >>> image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
1509
+ ```"""
1510
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1511
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1512
+ output_hidden_states = (
1513
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1514
+ )
1515
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1516
+
1517
+ vision_outputs = self.vision_model(
1518
+ pixel_values=pixel_values,
1519
+ output_attentions=output_attentions,
1520
+ output_hidden_states=output_hidden_states,
1521
+ interpolate_pos_encoding=interpolate_pos_encoding,
1522
+ return_dict=return_dict,
1523
+ )
1524
+
1525
+ pooled_output = vision_outputs[1] # pooled_output
1526
+ image_features = self.visual_projection(pooled_output)
1527
+
1528
+ return image_features
1529
+
1530
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING)
1531
+ @replace_return_docstrings(output_type=ChineseCLIPOutput, config_class=ChineseCLIPConfig)
1532
+ def forward(
1533
+ self,
1534
+ input_ids: Optional[torch.LongTensor] = None,
1535
+ pixel_values: Optional[torch.FloatTensor] = None,
1536
+ attention_mask: Optional[torch.Tensor] = None,
1537
+ token_type_ids: Optional[torch.Tensor] = None,
1538
+ position_ids: Optional[torch.LongTensor] = None,
1539
+ return_loss: Optional[bool] = None,
1540
+ output_attentions: Optional[bool] = None,
1541
+ output_hidden_states: Optional[bool] = None,
1542
+ interpolate_pos_encoding: bool = False,
1543
+ return_dict: Optional[bool] = None,
1544
+ ) -> Union[Tuple, ChineseCLIPOutput]:
1545
+ r"""
1546
+ Returns:
1547
+
1548
+ Examples:
1549
+
1550
+ ```python
1551
+ >>> from PIL import Image
1552
+ >>> import requests
1553
+ >>> from transformers import AutoProcessor, ChineseCLIPModel
1554
+
1555
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1556
+ >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1557
+
1558
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1559
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1560
+
1561
+ >>> inputs = processor(text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, return_tensors="pt", padding=True)
1562
+
1563
+ >>> outputs = model(**inputs)
1564
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1565
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1566
+ ```"""
1567
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1568
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1569
+ output_hidden_states = (
1570
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1571
+ )
1572
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1573
+
1574
+ vision_outputs = self.vision_model(
1575
+ pixel_values=pixel_values,
1576
+ output_attentions=output_attentions,
1577
+ output_hidden_states=output_hidden_states,
1578
+ interpolate_pos_encoding=interpolate_pos_encoding,
1579
+ return_dict=return_dict,
1580
+ )
1581
+
1582
+ text_outputs = self.text_model(
1583
+ input_ids=input_ids,
1584
+ attention_mask=attention_mask,
1585
+ token_type_ids=token_type_ids,
1586
+ position_ids=position_ids,
1587
+ output_attentions=output_attentions,
1588
+ output_hidden_states=output_hidden_states,
1589
+ return_dict=return_dict,
1590
+ )
1591
+
1592
+ image_embeds = vision_outputs[1]
1593
+ image_embeds = self.visual_projection(image_embeds)
1594
+
1595
+ text_embeds = text_outputs[0][:, 0, :]
1596
+ text_embeds = self.text_projection(text_embeds)
1597
+
1598
+ # normalized features
1599
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1600
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1601
+
1602
+ # cosine similarity as logits
1603
+ logit_scale = self.logit_scale.exp()
1604
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1605
+ logits_per_image = logits_per_text.t()
1606
+
1607
+ loss = None
1608
+ if return_loss:
1609
+ loss = chinese_clip_loss(logits_per_text)
1610
+
1611
+ if not return_dict:
1612
+ # fix the None pooled_output of text_outputs to conform with dict_output
1613
+ pooled_output = text_outputs[1]
1614
+ if pooled_output is None:
1615
+ text_outputs = (text_outputs[0],) + text_outputs[2:]
1616
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1617
+ return ((loss,) + output) if loss is not None else output
1618
+
1619
+ return ChineseCLIPOutput(
1620
+ loss=loss,
1621
+ logits_per_image=logits_per_image,
1622
+ logits_per_text=logits_per_text,
1623
+ text_embeds=text_embeds,
1624
+ image_embeds=image_embeds,
1625
+ text_model_output=text_outputs,
1626
+ vision_model_output=vision_outputs,
1627
+ )
1628
+
1629
+
1630
+ __all__ = ["ChineseCLIPModel", "ChineseCLIPPreTrainedModel", "ChineseCLIPTextModel", "ChineseCLIPVisionModel"]
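The model file above already documents each component separately; as a quick orientation, here is a minimal sketch of zero-shot image-text matching with `ChineseCLIPModel`. It mirrors the docstring example (same checkpoint and image URL) and assumes network access; nothing beyond what the docstrings show is implied.

```python
# Minimal zero-shot image-text matching sketch, mirroring the docstring example above.
# Assumes network access to the OFA-Sys checkpoint and the example image URL.
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, ChineseCLIPModel

model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")

url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]

inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image has shape (num_images, num_texts); softmax over the text axis
# gives per-label probabilities for each image.
probs = outputs.logits_per_image.softmax(dim=1)
print({text: round(prob, 4) for text, prob in zip(texts, probs[0].tolist())})
```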
docs/transformers/build/lib/transformers/models/chinese_clip/processing_chinese_clip.py ADDED
@@ -0,0 +1,163 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for Chinese-CLIP
17
+ """
18
+
19
+ import warnings
20
+ from typing import List, Union
21
+
22
+ from ...image_utils import ImageInput
23
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
24
+ from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
25
+
26
+
27
+ class ChineseClipProcessorKwargs(ProcessingKwargs, total=False):
28
+ _defaults = {}
29
+
30
+
31
+ class ChineseCLIPProcessor(ProcessorMixin):
32
+ r"""
33
+ Constructs a Chinese-CLIP processor which wraps a Chinese-CLIP image processor and a Chinese-CLIP tokenizer into a
34
+ single processor.
35
+
36
+ [`ChineseCLIPProcessor`] offers all the functionalities of [`ChineseCLIPImageProcessor`] and [`BertTokenizerFast`].
37
+ See the [`~ChineseCLIPProcessor.__call__`] and [`~ChineseCLIPProcessor.decode`] for more information.
38
+
39
+ Args:
40
+ image_processor ([`ChineseCLIPImageProcessor`], *optional*):
41
+ The image processor is a required input.
42
+ tokenizer ([`BertTokenizerFast`], *optional*):
43
+ The tokenizer is a required input.
44
+ """
45
+
46
+ attributes = ["image_processor", "tokenizer"]
47
+ image_processor_class = ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")
48
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
49
+
50
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
51
+ feature_extractor = None
52
+ if "feature_extractor" in kwargs:
53
+ warnings.warn(
54
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
55
+ " instead.",
56
+ FutureWarning,
57
+ )
58
+ feature_extractor = kwargs.pop("feature_extractor")
59
+
60
+ image_processor = image_processor if image_processor is not None else feature_extractor
61
+ if image_processor is None:
62
+ raise ValueError("You need to specify an `image_processor`.")
63
+ if tokenizer is None:
64
+ raise ValueError("You need to specify a `tokenizer`.")
65
+
66
+ super().__init__(image_processor, tokenizer)
67
+ self.current_processor = self.image_processor
68
+
69
+ def __call__(
70
+ self,
71
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
72
+ images: ImageInput = None,
73
+ audio=None,
74
+ videos=None,
75
+ **kwargs: Unpack[ChineseClipProcessorKwargs],
76
+ ) -> BatchEncoding:
77
+ """
78
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
79
+ and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
80
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
81
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
82
+ of the above two methods for more information.
83
+
84
+ Args:
85
+ text (`str`, `List[str]`, `List[List[str]]`):
86
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
87
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
88
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
89
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
90
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
91
+ tensor. Both channels-first and channels-last formats are supported.
92
+
93
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
94
+ If set, will return tensors of a particular framework. Acceptable values are:
95
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
96
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
97
+ - `'np'`: Return NumPy `np.ndarray` objects.
98
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
99
+ Returns:
100
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
101
+
102
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
103
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
104
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
105
+ `None`).
106
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
107
+ """
108
+
109
+ if text is None and images is None:
110
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
111
+ output_kwargs = self._merge_kwargs(
112
+ ChineseClipProcessorKwargs,
113
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
114
+ **kwargs,
115
+ )
116
+
117
+ if text is not None:
118
+ encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
119
+ if images is not None:
120
+ image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
121
+
122
+ # BC for explicit return_tensors
123
+ if "return_tensors" in output_kwargs["common_kwargs"]:
124
+ return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
125
+
126
+ if text is not None and images is not None:
127
+ encoding["pixel_values"] = image_features.pixel_values
128
+ return encoding
129
+ elif text is not None:
130
+ return encoding
131
+ else:
132
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
133
+
134
+ def batch_decode(self, *args, **kwargs):
135
+ """
136
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
137
+ refer to the docstring of this method for more information.
138
+ """
139
+ return self.tokenizer.batch_decode(*args, **kwargs)
140
+
141
+ def decode(self, *args, **kwargs):
142
+ """
143
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
144
+ the docstring of this method for more information.
145
+ """
146
+ return self.tokenizer.decode(*args, **kwargs)
147
+
148
+ @property
149
+ def model_input_names(self):
150
+ tokenizer_input_names = self.tokenizer.model_input_names
151
+ image_processor_input_names = self.image_processor.model_input_names
152
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
153
+
154
+ @property
155
+ def feature_extractor_class(self):
156
+ warnings.warn(
157
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
158
+ FutureWarning,
159
+ )
160
+ return self.image_processor_class
161
+
162
+
163
+ __all__ = ["ChineseCLIPProcessor"]
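To illustrate the three calling modes handled by `ChineseCLIPProcessor.__call__` above (text only, images only, or both), here is a small sketch. The checkpoint is the one used elsewhere in this commit, and `example.jpg` is a hypothetical local image path.

```python
# Sketch of ChineseCLIPProcessor usage; "example.jpg" is a placeholder image path.
from PIL import Image

from transformers import ChineseCLIPProcessor

processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
image = Image.open("example.jpg")

# Text only -> tokenizer output (input_ids, token_type_ids, attention_mask).
text_batch = processor(text=["皮卡丘", "小火龙"], padding=True, return_tensors="pt")

# Images only -> image processor output (pixel_values).
image_batch = processor(images=image, return_tensors="pt")

# Both -> tokenizer output with pixel_values merged in, as described in the docstring.
both = processor(text=["皮卡丘", "小火龙"], images=image, padding=True, return_tensors="pt")
print(sorted(both.keys()))
```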
docs/transformers/build/lib/transformers/models/clap/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_clap import *
22
+ from .feature_extraction_clap import *
23
+ from .modeling_clap import *
24
+ from .processing_clap import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
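Since this `__init__` swaps the package module for a `_LazyModule`, the CLAP submodules are only imported when one of their exported names is first accessed. A small sketch of what that looks like from user code, assuming `ClapTextConfig` from the configuration file added below:

```python
# Sketch: the package import itself stays cheap; configuration_clap is only
# loaded when ClapTextConfig is first looked up on the lazy module.
import transformers.models.clap as clap

config = clap.ClapTextConfig()  # attribute access triggers the real import
print(type(config).__name__)    # -> "ClapTextConfig"
```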
docs/transformers/build/lib/transformers/models/clap/configuration_clap.py ADDED
@@ -0,0 +1,394 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CLAP model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class ClapTextConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`ClapTextModel`]. It is used to instantiate a CLAP
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the CLAP
29
+ [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 50265):
37
+ Vocabulary size of the CLAP model. Defines the number of different tokens that can be represented by the
38
+ `input_ids` passed when calling [`ClapTextModel`].
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 514):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ type_vocab_size (`int`, *optional*, defaults to 1):
58
+ The vocabulary size of the `token_type_ids` passed when calling [`ClapTextModel`].
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
60
+ The epsilon used by the layer normalization layers.
61
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
62
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
63
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
64
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
65
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
66
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
67
+ is_decoder (`bool`, *optional*, defaults to `False`):
68
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
69
+ use_cache (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
71
+ relevant if `config.is_decoder=True`.
72
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
73
+ The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`,
74
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
75
+ projection_dim (`int`, *optional*, defaults to 512):
76
+ Dimension of the projection head of the `ClapTextModelWithProjection`.
77
+
78
+ Examples:
79
+
80
+ ```python
81
+ >>> from transformers import ClapTextConfig, ClapTextModel
82
+
83
+ >>> # Initializing a CLAP text configuration
84
+ >>> configuration = ClapTextConfig()
85
+
86
+ >>> # Initializing a model (with random weights) from the configuration
87
+ >>> model = ClapTextModel(configuration)
88
+
89
+ >>> # Accessing the model configuration
90
+ >>> configuration = model.config
91
+ ```"""
92
+
93
+ model_type = "clap_text_model"
94
+ base_config_key = "text_config"
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_size=50265,
99
+ hidden_size=768,
100
+ num_hidden_layers=12,
101
+ num_attention_heads=12,
102
+ intermediate_size=3072,
103
+ hidden_act="gelu",
104
+ hidden_dropout_prob=0.1,
105
+ attention_probs_dropout_prob=0.1,
106
+ max_position_embeddings=514,
107
+ type_vocab_size=1,
108
+ initializer_factor=1.0,
109
+ layer_norm_eps=1e-12,
110
+ projection_dim=512,
111
+ pad_token_id=1,
112
+ bos_token_id=0,
113
+ eos_token_id=2,
114
+ position_embedding_type="absolute",
115
+ use_cache=True,
116
+ projection_hidden_act="relu",
117
+ **kwargs,
118
+ ):
119
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
120
+
121
+ self.vocab_size = vocab_size
122
+ self.hidden_size = hidden_size
123
+ self.num_hidden_layers = num_hidden_layers
124
+ self.num_attention_heads = num_attention_heads
125
+ self.hidden_act = hidden_act
126
+ self.intermediate_size = intermediate_size
127
+ self.hidden_dropout_prob = hidden_dropout_prob
128
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
129
+ self.max_position_embeddings = max_position_embeddings
130
+ self.type_vocab_size = type_vocab_size
131
+ self.initializer_factor = initializer_factor
132
+ self.layer_norm_eps = layer_norm_eps
133
+ self.position_embedding_type = position_embedding_type
134
+ self.use_cache = use_cache
135
+ self.projection_hidden_act = projection_hidden_act
136
+ self.projection_dim = projection_dim
137
+
138
+
139
+ class ClapAudioConfig(PretrainedConfig):
140
+ r"""
141
+ This is the configuration class to store the configuration of a [`ClapAudioModel`]. It is used to instantiate a
142
+ CLAP audio encoder according to the specified arguments, defining the model architecture. Instantiating a
143
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the CLAP
144
+ [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture.
145
+
146
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
147
+ documentation from [`PretrainedConfig`] for more information.
148
+
149
+ Args:
150
+ window_size (`int`, *optional*, defaults to 8):
151
+ Window size of the local attention used in the Swin layers of the audio encoder.
152
+ num_mel_bins (`int`, *optional*, defaults to 64):
153
+ Number of mel features used per frame. Should correspond to the value used in the `ClapProcessor` class.
154
+ spec_size (`int`, *optional*, defaults to 256):
155
+ Desired input size of the spectrogram that the model supports. It can be different from the output of the
156
+ `ClapFeatureExtractor`, in which case the input features will be resized. Corresponds to the `image_size`
157
+ of the audio models.
158
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
159
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
160
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
161
+ patch_size (`int`, *optional*, defaults to 4):
162
+ Patch size for the audio spectrogram
163
+ patch_stride (`list`, *optional*, defaults to `[4, 4]`):
164
+ Patch stride for the audio spectrogram
165
+ num_classes (`int`, *optional*, defaults to 527):
166
+ Number of classes used for the head training
167
+ hidden_size (`int`, *optional*, defaults to 768):
168
+ Hidden size of the output of the audio encoder. Corresponds to the dimension of the penultimate layer's
169
+ output, which is sent to the projection MLP layer.
170
+ projection_dim (`int`, *optional*, defaults to 512):
171
+ Hidden size of the projection layer.
172
+ depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`):
173
+ Depths used for the Swin Layers of the audio model
174
+ num_attention_heads (`list`, *optional*, defaults to `[4, 8, 16, 32]`):
175
+ Number of attention heads used for the Swin Layers of the audio model
176
+ enable_fusion (`bool`, *optional*, defaults to `False`):
177
+ Whether or not to enable patch fusion. This is the main contribution of the authors, and should give the
178
+ best results.
179
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
180
+ The dropout probability for all fully connected layers in the encoder.
181
+ fusion_type (`str`, *optional*):
182
+ Fusion type used for the patch fusion.
183
+ patch_embed_input_channels (`int`, *optional*, defaults to 1):
184
+ Number of channels used for the input spectrogram
185
+ flatten_patch_embeds (`bool`, *optional*, defaults to `True`):
186
+ Whether or not to flatten the patch embeddings
187
+ patch_embeds_hidden_size (`int`, *optional*, defaults to 96):
188
+ Hidden size of the patch embeddings. It is used as the number of output channels.
189
+ enable_patch_layer_norm (`bool`, *optional*, defaults to `True`):
190
+ Whether or not to enable layer normalization for the patch embeddings
191
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
192
+ Drop path rate for the patch fusion
193
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
194
+ The dropout ratio for the attention probabilities.
195
+ qkv_bias (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to add a bias to the query, key, value projections.
197
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
198
+ Ratio of the mlp hidden dim to embedding dim.
199
+ aff_block_r (`int`, *optional*, defaults to 4):
200
+ Downsize ratio used in the attentional feature fusion (AFF) block
201
+ num_hidden_layers (`int`, *optional*, defaults to 4):
202
+ Number of hidden layers in the Transformer encoder.
203
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
204
+ The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`,
205
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
206
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
207
+ The epsilon used by the layer normalization layers.
208
+ initializer_factor (`float`, *optional*, defaults to 1.0):
209
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
210
+ testing).
211
+
212
+ Example:
213
+
214
+ ```python
215
+ >>> from transformers import ClapAudioConfig, ClapAudioModel
216
+
217
+ >>> # Initializing a ClapAudioConfig with laion/clap-htsat-fused style configuration
218
+ >>> configuration = ClapAudioConfig()
219
+
220
+ >>> # Initializing a ClapAudioModel (with random weights) from the laion/clap-htsat-fused style configuration
221
+ >>> model = ClapAudioModel(configuration)
222
+
223
+ >>> # Accessing the model configuration
224
+ >>> configuration = model.config
225
+ ```"""
226
+
227
+ model_type = "clap_audio_model"
228
+ base_config_key = "audio_config"
229
+
230
+ def __init__(
231
+ self,
232
+ window_size=8,
233
+ num_mel_bins=64,
234
+ spec_size=256,
235
+ hidden_act="gelu",
236
+ patch_size=4,
237
+ patch_stride=[4, 4],
238
+ num_classes=527,
239
+ hidden_size=768,
240
+ projection_dim=512,
241
+ depths=[2, 2, 6, 2],
242
+ num_attention_heads=[4, 8, 16, 32],
243
+ enable_fusion=False,
244
+ hidden_dropout_prob=0.1,
245
+ fusion_type=None,
246
+ patch_embed_input_channels=1,
247
+ flatten_patch_embeds=True,
248
+ patch_embeds_hidden_size=96,
249
+ enable_patch_layer_norm=True,
250
+ drop_path_rate=0.0,
251
+ attention_probs_dropout_prob=0.0,
252
+ qkv_bias=True,
253
+ mlp_ratio=4.0,
254
+ aff_block_r=4,
255
+ num_hidden_layers=4,
256
+ projection_hidden_act="relu",
257
+ layer_norm_eps=1e-5,
258
+ initializer_factor=1.0,
259
+ **kwargs,
260
+ ):
261
+ super().__init__(**kwargs)
262
+ self.window_size = window_size
263
+ self.num_mel_bins = num_mel_bins
264
+ self.spec_size = spec_size
265
+ self.patch_size = patch_size
266
+ self.patch_stride = patch_stride
267
+ self.num_classes = num_classes
268
+ self.hidden_size = hidden_size
269
+ self.depths = depths
270
+ self.num_hidden_layers = num_hidden_layers
271
+ self.num_attention_heads = num_attention_heads
272
+ self.window_size = window_size
273
+ self.enable_fusion = enable_fusion
274
+ self.fusion_type = fusion_type
275
+ self.hidden_act = hidden_act
276
+ self.hidden_dropout_prob = hidden_dropout_prob
277
+ self.projection_dim = projection_dim
278
+ self.flatten_patch_embeds = flatten_patch_embeds
279
+ self.patch_embeds_hidden_size = patch_embeds_hidden_size
280
+ self.enable_patch_layer_norm = enable_patch_layer_norm
281
+ self.drop_path_rate = drop_path_rate
282
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
283
+ self.qkv_bias = qkv_bias
284
+ self.mlp_ratio = mlp_ratio
285
+ self.patch_embed_input_channels = patch_embed_input_channels
286
+ self.aff_block_r = aff_block_r
287
+ self.layer_norm_eps = layer_norm_eps
288
+ self.initializer_factor = initializer_factor
289
+ self.projection_hidden_act = projection_hidden_act
290
+
291
+
292
+ class ClapConfig(PretrainedConfig):
293
+ r"""
294
+ [`ClapConfig`] is the configuration class to store the configuration of a [`ClapModel`]. It is used to instantiate
295
+ a CLAP model according to the specified arguments, defining the text model and audio model configs. Instantiating a
296
+ configuration with the defaults will yield a similar configuration to that of the CLAP
297
+ [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture.
298
+
299
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
300
+ documentation from [`PretrainedConfig`] for more information.
301
+
302
+ Args:
303
+ text_config (`dict`, *optional*):
304
+ Dictionary of configuration options used to initialize [`ClapTextConfig`].
305
+ audio_config (`dict`, *optional*):
306
+ Dictionary of configuration options used to initialize [`ClapAudioConfig`].
307
+ logit_scale_init_value (`float`, *optional*, defaults to 14.29):
308
+ The initial value of the *logit_scale* parameter. Default is used as per the original CLAP implementation.
309
+ projection_dim (`int`, *optional*, defaults to 512):
310
+ Dimensionality of text and audio projection layers.
311
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
312
+ Activation function for the projection layers.
313
+ initializer_factor (`float`, *optional*, defaults to 1.0):
314
+ Factor to scale the initialization of the model weights.
315
+ kwargs (*optional*):
316
+ Dictionary of keyword arguments.
317
+
318
+ Example:
319
+
320
+ ```python
321
+ >>> from transformers import ClapConfig, ClapModel
322
+
323
+ >>> # Initializing a ClapConfig with laion-ai/base style configuration
324
+ >>> configuration = ClapConfig()
325
+
326
+ >>> # Initializing a ClapModel (with random weights) from the laion-ai/base style configuration
327
+ >>> model = ClapModel(configuration)
328
+
329
+ >>> # Accessing the model configuration
330
+ >>> configuration = model.config
331
+
332
+ >>> # We can also initialize a ClapConfig from a ClapTextConfig and a ClapAudioConfig
333
+ >>> from transformers import ClapTextConfig, ClapAudioConfig
334
+
335
+ >>> # Initializing a ClapText and ClapAudioConfig configuration
336
+ >>> config_text = ClapTextConfig()
337
+ >>> config_audio = ClapAudioConfig()
338
+
339
+ >>> config = ClapConfig.from_text_audio_configs(config_text, config_audio)
340
+ ```"""
341
+
342
+ model_type = "clap"
343
+ sub_configs = {"text_config": ClapTextConfig, "audio_config": ClapAudioConfig}
344
+
345
+ def __init__(
346
+ self,
347
+ text_config=None,
348
+ audio_config=None,
349
+ logit_scale_init_value=(1 / 0.07),
350
+ projection_dim=512,
351
+ projection_hidden_act="relu",
352
+ initializer_factor=1.0,
353
+ **kwargs,
354
+ ):
355
+ super().__init__(**kwargs)
356
+
357
+ if text_config is None:
358
+ text_config = {}
359
+ logger.info("text_config is None. Initializing the ClapTextConfig with default values.")
360
+
361
+ if audio_config is None:
362
+ audio_config = {}
363
+ logger.info("audio_config is None. Initializing the ClapAudioConfig with default values.")
364
+
365
+ self.text_config = ClapTextConfig(**text_config)
366
+ self.audio_config = ClapAudioConfig(**audio_config)
367
+ self.text_config.projection_dim = projection_dim
368
+ self.audio_config.projection_dim = projection_dim
369
+
370
+ self.text_config.projection_hidden_act = projection_hidden_act
371
+ self.audio_config.projection_hidden_act = projection_hidden_act
372
+
373
+ self.projection_dim = projection_dim
374
+ self.projection_hidden_act = projection_hidden_act
375
+ self.hidden_size = self.text_config.hidden_size
376
+
377
+ self.logit_scale_init_value = logit_scale_init_value
378
+ self.initializer_factor = initializer_factor
379
+ self.num_hidden_layers = self.text_config.num_hidden_layers + len(self.audio_config.depths)
380
+
381
+ @classmethod
382
+ def from_text_audio_configs(cls, text_config: ClapTextConfig, audio_config: ClapAudioConfig, **kwargs):
383
+ r"""
384
+ Instantiate a [`ClapConfig`] (or a derived class) from clap text model configuration and clap audio model
385
+ configuration.
386
+
387
+ Returns:
388
+ [`ClapConfig`]: An instance of a configuration object
389
+ """
390
+
391
+ return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
392
+
393
+
394
+ __all__ = ["ClapAudioConfig", "ClapConfig", "ClapTextConfig"]
docs/transformers/build/lib/transformers/models/clap/convert_clap_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,133 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import re
18
+
19
+ from laion_clap import CLAP_Module
20
+
21
+ from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
22
+
23
+
24
+ KEYS_TO_MODIFY_MAPPING = {
25
+ "text_branch": "text_model",
26
+ "audio_branch": "audio_model.audio_encoder",
27
+ "attn": "attention.self",
28
+ "self.proj": "output.dense",
29
+ "attention.self_mask": "attn_mask",
30
+ "mlp.fc1": "intermediate.dense",
31
+ "mlp.fc2": "output.dense",
32
+ "norm1": "layernorm_before",
33
+ "norm2": "layernorm_after",
34
+ "bn0": "batch_norm",
35
+ }
36
+
37
+ processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
38
+
39
+
40
+ def init_clap(checkpoint_path, model_type, enable_fusion=False):
41
+ model = CLAP_Module(
42
+ amodel=model_type,
43
+ enable_fusion=enable_fusion,
44
+ )
45
+ model.load_ckpt(checkpoint_path)
46
+ return model
47
+
48
+
49
+ def get_config_from_original(clap_model):
50
+ audio_config = {
51
+ "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
52
+ "depths": clap_model.model.audio_branch.depths,
53
+ "hidden_size": clap_model.model.audio_projection[0].in_features,
54
+ }
55
+
56
+ text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
57
+
58
+ return ClapConfig(audio_config=audio_config, text_config=text_config)
59
+
60
+
61
+ def rename_state_dict(state_dict):
62
+ model_state_dict = {}
63
+
64
+ sequential_layers_pattern = r".*sequential.(\d+).*"
65
+ text_projection_pattern = r".*_projection.(\d+).*"
66
+
67
+ for key, value in state_dict.items():
68
+ # check if any key needs to be modified
69
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
70
+ if key_to_modify in key:
71
+ key = key.replace(key_to_modify, new_key)
72
+
73
+ if re.match(sequential_layers_pattern, key):
74
+ # replace sequential layers with list
75
+ sequential_layer = re.match(sequential_layers_pattern, key).group(1)
76
+
77
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
78
+ elif re.match(text_projection_pattern, key):
79
+ projection_layer = int(re.match(text_projection_pattern, key).group(1))
80
+
81
+ # Because in CLAP they use `nn.Sequential`...
82
+ transformers_projection_layer = 1 if projecton_layer == 0 else 2
83
+
84
+ key = key.replace(f"_projection.{projection_layer}.", f"_projection.linear{transformers_projection_layer}.")
85
+
86
+ if "audio" in key and "qkv" in key:
87
+ # split qkv into query key and value
88
+ mixed_qkv = value
89
+ qkv_dim = mixed_qkv.size(0) // 3
90
+
91
+ query_layer = mixed_qkv[:qkv_dim]
92
+ key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
93
+ value_layer = mixed_qkv[qkv_dim * 2 :]
94
+
95
+ model_state_dict[key.replace("qkv", "query")] = query_layer
96
+ model_state_dict[key.replace("qkv", "key")] = key_layer
97
+ model_state_dict[key.replace("qkv", "value")] = value_layer
98
+ else:
99
+ model_state_dict[key] = value
100
+
101
+ return model_state_dict
102
+
103
+
104
+ def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
105
+ clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
106
+
107
+ clap_model.eval()
108
+ state_dict = clap_model.model.state_dict()
109
+ state_dict = rename_state_dict(state_dict)
110
+
111
+ transformers_config = get_config_from_original(clap_model)
112
+ transformers_config.audio_config.enable_fusion = enable_fusion
113
+ model = ClapModel(transformers_config)
114
+
115
+ # ignore the spectrogram embedding layer
116
+ model.load_state_dict(state_dict, strict=False)
117
+
118
+ model.save_pretrained(pytorch_dump_folder_path)
119
+ transformers_config.save_pretrained(pytorch_dump_folder_path)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
125
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the original CLAP checkpoint")
126
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
127
+ parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
128
+ parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Audio model type of the original checkpoint (e.g. HTSAT-tiny)")
129
+ args = parser.parse_args()
130
+
131
+ convert_clap_checkpoint(
132
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
133
+ )
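For reference, a toy sketch (not part of the diff) of the substring remapping that `rename_state_dict` above applies via `KEYS_TO_MODIFY_MAPPING`; the key is made up for illustration:

```python
# Illustrative only: re-applies a subset of the mapping used by rename_state_dict above.
key = "text_branch.encoder.layer.0.mlp.fc1.weight"  # hypothetical original laion_clap key
for old, new in {"text_branch": "text_model", "mlp.fc1": "intermediate.dense"}.items():
    if old in key:
        key = key.replace(old, new)
print(key)  # -> text_model.encoder.layer.0.intermediate.dense.weight
```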
docs/transformers/build/lib/transformers/models/clap/feature_extraction_clap.py ADDED
@@ -0,0 +1,367 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for CLAP."""
16
+
17
+ import copy
18
+ from typing import Any, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from ...audio_utils import mel_filter_bank, spectrogram, window_function
24
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
25
+ from ...feature_extraction_utils import BatchFeature
26
+ from ...utils import TensorType, logging
27
+ from ...utils.import_utils import requires
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ @requires(backends=("torch",))
34
+ class ClapFeatureExtractor(SequenceFeatureExtractor):
35
+ r"""
36
+ Constructs a CLAP feature extractor.
37
+
38
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
39
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
40
+
41
+ This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time
42
+ Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent.
43
+
44
+ Args:
45
+ feature_size (`int`, *optional*, defaults to 64):
46
+ The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters
47
+ (`n_mels`).
48
+ sampling_rate (`int`, *optional*, defaults to 48000):
49
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves
50
+ to warn users if the audio fed to the feature extractor does not have the same sampling rate.
51
+ hop_length (`int`, *optional*, defaults to 480):
52
+ Length of the overlapping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split
53
+ in smaller `frames` with a step of `hop_length` between each frame.
54
+ max_length_s (`int`, *optional*, defaults to 10):
55
+ The maximum input length of the model in seconds. This is used to pad the audio.
56
+ fft_window_size (`int`, *optional*, defaults to 1024):
57
+ Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency
58
+ resolution of the spectrogram. For example, 400 means that the Fourier transform is computed on windows of 400 samples.
59
+ padding_value (`float`, *optional*, defaults to 0.0):
60
+ Padding value used to pad the audio. Should correspond to silences.
61
+ return_attention_mask (`bool`, *optional*, defaults to `False`):
62
+ Whether or not the model should return the attention masks corresponding to the input.
63
+ frequency_min (`float`, *optional*, defaults to 0):
64
+ The lowest frequency of interest. The STFT will not be computed for values below this.
65
+ frequency_max (`float`, *optional*, defaults to 14000):
66
+ The highest frequency of interest. The STFT will not be computed for values above this.
67
+ top_db (`float`, *optional*):
68
+ The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the
69
+ `audio_utils.power_to_db` function
70
+ truncation (`str`, *optional*, defaults to `"fusion"`):
71
+ Truncation pattern for long audio inputs. Two patterns are available:
72
+ - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
73
+ downsampled version of the entire mel spectrogram.
74
+ If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy
75
+ of the original mel obtained from the padded audio.
76
+ - `rand_trunc` will select a random crop of the mel spectrogram.
77
+ padding (`str`, *optional*, defaults to `"repeatpad"`):
78
+ Padding pattern for shorter audio inputs. Three patterns were originally implemented:
79
+ - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
80
+ - `repeat`: the audio is repeated and then cut to fit the `max_length`
81
+ - `pad`: the audio is padded.
82
+ """
83
+
84
+ model_input_names = ["input_features", "is_longer"]
85
+
86
+ def __init__(
87
+ self,
88
+ feature_size=64,
89
+ sampling_rate=48_000,
90
+ hop_length=480,
91
+ max_length_s=10,
92
+ fft_window_size=1024,
93
+ padding_value=0.0,
94
+ return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
95
+ frequency_min: float = 0,
96
+ frequency_max: float = 14_000,
97
+ top_db: Optional[int] = None,
98
+ truncation: str = "fusion",
99
+ padding: str = "repeatpad",
100
+ **kwargs,
101
+ ):
102
+ super().__init__(
103
+ feature_size=feature_size,
104
+ sampling_rate=sampling_rate,
105
+ padding_value=padding_value,
106
+ return_attention_mask=return_attention_mask,
107
+ **kwargs,
108
+ )
109
+ self.top_db = top_db
110
+ self.truncation = truncation
111
+ self.padding = padding
112
+ self.fft_window_size = fft_window_size
113
+ self.nb_frequency_bins = (fft_window_size >> 1) + 1
114
+ self.hop_length = hop_length
115
+ self.max_length_s = max_length_s
116
+ self.nb_max_samples = max_length_s * sampling_rate
117
+ self.sampling_rate = sampling_rate
118
+ self.frequency_min = frequency_min
119
+ self.frequency_max = frequency_max
120
+ self.mel_filters = mel_filter_bank(
121
+ num_frequency_bins=self.nb_frequency_bins,
122
+ num_mel_filters=feature_size,
123
+ min_frequency=frequency_min,
124
+ max_frequency=frequency_max,
125
+ sampling_rate=sampling_rate,
126
+ norm=None,
127
+ mel_scale="htk",
128
+ )
129
+ self.mel_filters_slaney = mel_filter_bank(
130
+ num_frequency_bins=self.nb_frequency_bins,
131
+ num_mel_filters=feature_size,
132
+ min_frequency=frequency_min,
133
+ max_frequency=frequency_max,
134
+ sampling_rate=sampling_rate,
135
+ norm="slaney",
136
+ mel_scale="slaney",
137
+ )
138
+
139
+ def to_dict(self) -> Dict[str, Any]:
140
+ """
141
+ Serializes this instance to a Python dictionary.
142
+
143
+ Returns:
144
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for the
145
+ mel filter banks, which do not need to be saved or printed as they are too long.
146
+ """
147
+ output = copy.deepcopy(self.__dict__)
148
+ output["feature_extractor_type"] = self.__class__.__name__
149
+ if "mel_filters" in output:
150
+ del output["mel_filters"]
151
+ if "mel_filters_slaney" in output:
152
+ del output["mel_filters_slaney"]
153
+ return output
154
+
155
+ def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
156
+ """
157
+ Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
158
+ banks are used depending on the truncation pattern:
159
+ - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
160
+ calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
161
+ is set to `"fusion"`.
162
+ - `self.mel_filters_slaney`: they correspond to the default parameters of `librosa`, which uses
163
+ `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
164
+ implementation when the truncation mode is not `"fusion"`.
165
+ """
166
+ log_mel_spectrogram = spectrogram(
167
+ waveform,
168
+ window_function(self.fft_window_size, "hann"),
169
+ frame_length=self.fft_window_size,
170
+ hop_length=self.hop_length,
171
+ power=2.0,
172
+ mel_filters=mel_filters,
173
+ log_mel="dB",
174
+ )
175
+ return log_mel_spectrogram.T
176
+
177
+ def _random_mel_fusion(self, mel, total_frames, chunk_frames):
178
+ ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
179
+ if len(ranges[1]) == 0:
180
+ # if the audio is too short, we just use the first chunk
181
+ ranges[1] = [0]
182
+ if len(ranges[2]) == 0:
183
+ # if the audio is too short, we just use the first chunk
184
+ ranges[2] = [0]
185
+ # randomly choose index for each part
186
+ idx_front = np.random.choice(ranges[0])
187
+ idx_middle = np.random.choice(ranges[1])
188
+ idx_back = np.random.choice(ranges[2])
189
+
190
+ mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
191
+ mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
192
+ mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
193
+
194
+ mel = torch.tensor(mel[None, None, :])
195
+ mel_shrink = torch.nn.functional.interpolate(
196
+ mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
197
+ )
198
+ mel_shrink = mel_shrink[0][0].numpy()
199
+ mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
200
+ return mel_fusion
201
+
202
+ def _get_input_mel(self, waveform: np.array, max_length, truncation, padding) -> np.array:
203
+ """
204
+ Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments.
205
+ Four different path are possible:
206
+ - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
207
+ will be computed on the entire audio. 3 random crops and a dowsampled version of the full mel spectrogram
208
+ are then stacked together. They will later be used for `feature_fusion`.
209
+ - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
210
+ padded based on `padding`.
211
+ - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
212
+ based on `padding`, and is repeated `4` times.
213
+ - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
214
+ spectrogram will be computed on a random crop of the waveform.
215
+
216
+ """
217
+ if waveform.shape[0] > max_length:
218
+ if truncation == "rand_trunc":
219
+ longer = True
220
+ # random crop to max_length (for compatibility) -> this should be handled by self.pad
221
+ overflow = len(waveform) - max_length
222
+ idx = np.random.randint(0, overflow + 1)
223
+ waveform = waveform[idx : idx + max_length]
224
+ input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
225
+ elif truncation == "fusion":
226
+ mel = self._np_extract_fbank_features(waveform, self.mel_filters)
227
+ chunk_frames = max_length // self.hop_length + 1 # the +1 is related to how the spectrogram is computed
228
+ total_frames = mel.shape[0]
229
+ if chunk_frames == total_frames:
230
+ # there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length.
231
+ # In this case, we just use the whole audio.
232
+ input_mel = np.stack([mel, mel, mel, mel], axis=0)
233
+ longer = False
234
+ else:
235
+ input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames)
236
+ longer = True
237
+ else:
238
+ raise NotImplementedError(f"data_truncating {truncation} not implemented")
239
+
240
+ else:
241
+ longer = False
242
+ # only use repeat as a new possible value for padding. you repeat the audio before applying the usual max_length padding
243
+ if waveform.shape[0] < max_length:
244
+ if padding == "repeat":
245
+ n_repeat = int(max_length / len(waveform))
246
+ waveform = np.tile(waveform, n_repeat + 1)[:max_length]
247
+ if padding == "repeatpad":
248
+ n_repeat = int(max_length / len(waveform))
249
+ waveform = np.tile(waveform, n_repeat)
250
+ waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0)
251
+
252
+ if truncation == "fusion":
253
+ input_mel = self._np_extract_fbank_features(waveform, self.mel_filters)
254
+ input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0)
255
+ else:
256
+ input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
257
+
258
+ return input_mel, longer
259
+
260
+ def __call__(
261
+ self,
262
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
263
+ truncation: Optional[str] = None,
264
+ padding: Optional[str] = None,
265
+ max_length: Optional[int] = None,
266
+ sampling_rate: Optional[int] = None,
267
+ return_tensors: Optional[Union[str, TensorType]] = None,
268
+ **kwargs,
269
+ ) -> BatchFeature:
270
+ """
271
+ Main method to featurize and prepare for the model one or several sequence(s).
272
+
273
+ Args:
274
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
275
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
276
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
277
+ stereo, i.e. single float per timestep.
278
+ truncation (`str`, *optional*):
279
+ Truncation pattern for long audio inputs. Two patterns are available:
280
+ - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
281
+ a downsampled version of the entire mel spectrogram.
282
+ If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
283
+ copy of the original mel obtained from the padded audio.
284
+ - `rand_trunc` will select a random crop of the mel spectrogram.
285
+ padding (`str`, *optional*):
286
+ Padding pattern for shorter audio inputs. Three patterns were originally implemented:
287
+ - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
288
+ - `repeat`: the audio is repeated and then cut to fit the `max_length`
289
+ - `pad`: the audio is padded.
290
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
291
+ If set, will return tensors instead of list of python integers. Acceptable values are:
292
+
293
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
294
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
295
+ - `'np'`: Return Numpy `np.ndarray` objects.
296
+ sampling_rate (`int`, *optional*):
297
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
298
+ `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
299
+ pipeline.
300
+ """
301
+ truncation = truncation if truncation is not None else self.truncation
302
+ padding = padding if padding else self.padding
303
+
304
+ if sampling_rate is not None:
305
+ if sampling_rate != self.sampling_rate:
306
+ raise ValueError(
307
+ f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
308
+ f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
309
+ f" was sampled with {self.sampling_rate} and not {sampling_rate}."
310
+ )
311
+ else:
312
+ logger.warning(
313
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
314
+ "Failing to do so can result in silent errors that might be hard to debug."
315
+ )
316
+
317
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
318
+ if is_batched_numpy and len(raw_speech.shape) > 2:
319
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
320
+ is_batched = is_batched_numpy or (
321
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
322
+ )
323
+
324
+ if is_batched:
325
+ raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech]
326
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
327
+ raw_speech = np.asarray(raw_speech, dtype=np.float64)
328
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
329
+ raw_speech = raw_speech.astype(np.float64)
330
+
331
+ # always return batch
332
+ if not is_batched:
333
+ raw_speech = [np.asarray(raw_speech)]
334
+
335
+ # convert to mel spectrogram, truncate and pad if needed.
336
+ padded_inputs = [
337
+ self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding)
338
+ for waveform in raw_speech
339
+ ]
340
+
341
+ input_mel = []
342
+ is_longer = []
343
+ for mel, longer in padded_inputs:
344
+ input_mel.append(mel)
345
+ is_longer.append(longer)
346
+
347
+ if truncation == "fusion" and sum(is_longer) == 0:
348
+ # if no audio is longer than 10s, then randomly select one audio to be longer
349
+ rand_idx = np.random.randint(0, len(input_mel))
350
+ is_longer[rand_idx] = True
351
+
352
+ if isinstance(input_mel[0], List):
353
+ input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel]
354
+
355
+ # is_longer is a list of bool
356
+ is_longer = [[longer] for longer in is_longer]
357
+
358
+ input_features = {"input_features": input_mel, "is_longer": is_longer}
359
+ input_features = BatchFeature(input_features)
360
+
361
+ if return_tensors is not None:
362
+ input_features = input_features.convert_to_tensors(return_tensors)
363
+
364
+ return input_features
365
+
366
+
367
+ __all__ = ["ClapFeatureExtractor"]
docs/transformers/build/lib/transformers/models/clap/modeling_clap.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/clap/processing_clap.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio/Text processor class for CLAP
17
+ """
18
+
19
+ from ...processing_utils import ProcessorMixin
20
+ from ...tokenization_utils_base import BatchEncoding
21
+
22
+
23
+ class ClapProcessor(ProcessorMixin):
24
+ r"""
25
+ Constructs a CLAP processor which wraps a CLAP feature extractor and a RoBERTa tokenizer into a single processor.
26
+
27
+ [`ClapProcessor`] offers all the functionalities of [`ClapFeatureExtractor`] and [`RobertaTokenizerFast`]. See the
28
+ [`~ClapProcessor.__call__`] and [`~ClapProcessor.decode`] for more information.
29
+
30
+ Args:
31
+ feature_extractor ([`ClapFeatureExtractor`]):
32
+ The audio processor is a required input.
33
+ tokenizer ([`RobertaTokenizerFast`]):
34
+ The tokenizer is a required input.
35
+ """
36
+
37
+ feature_extractor_class = "ClapFeatureExtractor"
38
+ tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
39
+
40
+ def __init__(self, feature_extractor, tokenizer):
41
+ super().__init__(feature_extractor, tokenizer)
42
+
43
+ def __call__(self, text=None, audios=None, return_tensors=None, **kwargs):
44
+ """
45
+ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
46
+ and `kwargs` arguments to RobertaTokenizerFast's [`~RobertaTokenizerFast.__call__`] if `text` is not `None` to
47
+ encode the text. To prepare the audio(s), this method forwards the `audios` and `kwargs` arguments to
48
+ ClapFeatureExtractor's [`~ClapFeatureExtractor.__call__`] if `audios` is not `None`. Please refer to the
49
+ docstring of the above two methods for more information.
50
+
51
+ Args:
52
+ text (`str`, `List[str]`, `List[List[str]]`):
53
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
54
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
55
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
56
+ audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
57
+ The audio or batch of audios to be prepared. Each audio can be a NumPy array or a PyTorch tensor. In case
58
+ of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is the number of channels
59
+ and T the sample length of the audio.
60
+
61
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
62
+ If set, will return tensors of a particular framework. Acceptable values are:
63
+
64
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
65
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
66
+ - `'np'`: Return NumPy `np.ndarray` objects.
67
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
68
+
69
+ Returns:
70
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
71
+
72
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
73
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
74
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
75
+ `None`).
76
+ - **audio_features** -- Audio features to be fed to a model. Returned when `audios` is not `None`.
77
+ """
78
+ sampling_rate = kwargs.pop("sampling_rate", None)
79
+
80
+ if text is None and audios is None:
81
+ raise ValueError("You have to specify either text or audios. Both cannot be none.")
82
+
83
+ if text is not None:
84
+ encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
85
+
86
+ if audios is not None:
87
+ audio_features = self.feature_extractor(
88
+ audios, sampling_rate=sampling_rate, return_tensors=return_tensors, **kwargs
89
+ )
90
+
91
+ if text is not None and audios is not None:
92
+ encoding.update(audio_features)
93
+ return encoding
94
+ elif text is not None:
95
+ return encoding
96
+ else:
97
+ return BatchEncoding(data=dict(**audio_features), tensor_type=return_tensors)
98
+
99
+ def batch_decode(self, *args, **kwargs):
100
+ """
101
+ This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
102
+ refer to the docstring of this method for more information.
103
+ """
104
+ return self.tokenizer.batch_decode(*args, **kwargs)
105
+
106
+ def decode(self, *args, **kwargs):
107
+ """
108
+ This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
109
+ to the docstring of this method for more information.
110
+ """
111
+ return self.tokenizer.decode(*args, **kwargs)
112
+
113
+ @property
114
+ def model_input_names(self):
115
+ tokenizer_input_names = self.tokenizer.model_input_names
116
+ feature_extractor_input_names = self.feature_extractor.model_input_names
117
+ return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
118
+
119
+
120
+ __all__ = ["ClapProcessor"]
docs/transformers/build/lib/transformers/models/clip/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_clip import *
22
+ from .feature_extraction_clip import *
23
+ from .image_processing_clip import *
24
+ from .image_processing_clip_fast import *
25
+ from .modeling_clip import *
26
+ from .modeling_flax_clip import *
27
+ from .modeling_tf_clip import *
28
+ from .processing_clip import *
29
+ from .tokenization_clip import *
30
+ from .tokenization_clip_fast import *
31
+ else:
32
+ import sys
33
+
34
+ _file = globals()["__file__"]
35
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/clip/convert_clip_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,156 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+
18
+ import torch
19
+ from clip import load
20
+
21
+ from transformers import CLIPConfig, CLIPModel
22
+
23
+
24
+ def copy_attn_layer(hf_attn_layer, pt_attn_layer):
25
+ q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
26
+ q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
27
+
28
+ out_proj_weights = pt_attn_layer.out_proj.weight
29
+ out_proj_bias = pt_attn_layer.out_proj.bias
30
+
31
+ hf_attn_layer.q_proj.weight.data = q_proj
32
+ hf_attn_layer.q_proj.bias.data = q_proj_bias
33
+
34
+ hf_attn_layer.k_proj.weight.data = k_proj
35
+ hf_attn_layer.k_proj.bias.data = k_proj_bias
36
+
37
+ hf_attn_layer.v_proj.weight.data = v_proj
38
+ hf_attn_layer.v_proj.bias.data = v_proj_bias
39
+
40
+ hf_attn_layer.out_proj.weight = out_proj_weights
41
+ hf_attn_layer.out_proj.bias = out_proj_bias
42
+
43
+
44
+ def copy_mlp(hf_mlp, pt_mlp):
45
+ copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
46
+ copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
47
+
48
+
49
+ def copy_linear(hf_linear, pt_linear):
50
+ hf_linear.weight = pt_linear.weight
51
+ hf_linear.bias = pt_linear.bias
52
+
53
+
54
+ def copy_layer(hf_layer, pt_layer):
55
+ # copy layer norms
56
+ copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
57
+ copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
58
+
59
+ # copy MLP
60
+ copy_mlp(hf_layer.mlp, pt_layer.mlp)
61
+
62
+ # copy attn
63
+ copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
64
+
65
+
66
+ def copy_layers(hf_layers, pt_layers):
67
+ for hf_layer, pt_layer in zip(hf_layers, pt_layers):
68
+ copy_layer(hf_layer, pt_layer)
69
+
70
+
71
+ def copy_encoder(hf_encoder, pt_model):
72
+ # copy embeds
73
+ hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
74
+ hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
75
+
76
+ # copy layer norm
77
+ copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
78
+
79
+ # copy hidden layers
80
+ copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
81
+
82
+
83
+ def copy_text_model_and_projection(hf_model, pt_model):
84
+ # copy projection
85
+ hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
86
+
87
+ # copy text encoder
88
+ copy_encoder(hf_model.text_model, pt_model)
89
+
90
+
91
+ def copy_vision_model_and_projection(hf_model, pt_model):
92
+ # copy projection
93
+ hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
94
+
95
+ # copy layer norms
96
+ copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
97
+ copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
98
+
99
+ # copy embeds
100
+ hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
101
+ hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
102
+ hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
103
+
104
+ # copy encoder
105
+ copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
106
+
107
+
108
+ @torch.no_grad()
109
+ def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
110
+ """
111
+ Copy/paste/tweak model's weights to transformers design.
112
+ """
113
+ if config_path is not None:
114
+ config = CLIPConfig.from_pretrained(config_path)
115
+ else:
116
+ config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
117
+
118
+ hf_model = CLIPModel(config).eval()
119
+
120
+ pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
121
+ pt_model = pt_model.eval()
122
+
123
+ copy_text_model_and_projection(hf_model, pt_model)
124
+ copy_vision_model_and_projection(hf_model, pt_model)
125
+ hf_model.logit_scale = pt_model.logit_scale
126
+
127
+ # Use `eos_token` so the example is more meaningful
128
+ input_ids = torch.tensor(
129
+ [
130
+ [config.text_config.bos_token_id]
131
+ + list(range(3, 77))
132
+ + [config.text_config.eos_token_id]
133
+ + [config.text_config.pad_token_id]
134
+ ]
135
+ )
136
+ pixel_values = torch.randn(1, 3, 224, 224)
137
+
138
+ hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
139
+ hf_logits_per_image = hf_outputs.logits_per_image
140
+ hf_logits_per_text = hf_outputs.logits_per_text
141
+ pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
142
+
143
+ assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
144
+ assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
145
+
146
+ hf_model.save_pretrained(pytorch_dump_folder_path)
147
+
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
152
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
153
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
154
+ args = parser.parse_args()
155
+
156
+ convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
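The converter above can also be driven from Python instead of the CLI; a hedged sketch (not part of the diff, the output path is a placeholder, and `"ViT-B/32"` assumes the `clip` package resolves model names through its `load` helper as usual):

```python
# Illustrative only: equivalent to the argparse entry point above.
from convert_clip_original_pytorch_to_hf import convert_clip_checkpoint

convert_clip_checkpoint(
    checkpoint_path="ViT-B/32",                              # resolved by openai-clip's `load`
    pytorch_dump_folder_path="./clip-vit-base-patch32-hf",   # placeholder output directory
    config_path=None,                                        # fall back to the default CLIPConfig above
)
```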
old/.ipynb_checkpoints/dataset_10k_train-checkpoint.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f6360a5bc18603afd8cd64d3d7b6e9b5b55b204a53031ce3570be5f01aa05b
3
+ size 16739995
old/dataset_10k_train.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f6360a5bc18603afd8cd64d3d7b6e9b5b55b204a53031ce3570be5f01aa05b
3
+ size 16739995
seamless_interaction/assets/banner.gif ADDED

Git LFS Details

  • SHA256: 6b47141b5f3018e8387671dfe858090c810438902c6e6d72a7022c01e262b08c
  • Pointer size: 133 Bytes
  • Size of remote file: 36.2 MB
swift/llm/template/__pycache__/vision_utils.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
swift/llm/template/template/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
2
+ minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
swift/llm/template/template/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (606 Bytes).
swift/llm/template/template/__pycache__/deepseek.cpython-310.pyc ADDED
Binary file (11 kB).
swift/llm/template/template/__pycache__/emu3.cpython-310.pyc ADDED
Binary file (7.88 kB).
swift/llm/template/template/__pycache__/gemma.cpython-310.pyc ADDED
Binary file (5.91 kB).
swift/llm/template/template/__pycache__/glm.cpython-310.pyc ADDED
Binary file (13 kB).
swift/llm/template/template/__pycache__/idefics3.cpython-310.pyc ADDED
Binary file (1.53 kB).
swift/llm/template/template/__pycache__/internlm.cpython-310.pyc ADDED
Binary file (8.26 kB).
swift/llm/template/template/__pycache__/internvl.cpython-310.pyc ADDED
Binary file (6.8 kB).
swift/llm/template/template/__pycache__/llama.cpython-310.pyc ADDED
Binary file (9.74 kB).
swift/llm/template/template/__pycache__/llava.cpython-310.pyc ADDED
Binary file (10.7 kB).
swift/llm/template/template/__pycache__/llm.cpython-310.pyc ADDED
Binary file (7.88 kB).
swift/llm/template/template/__pycache__/megrez.cpython-310.pyc ADDED
Binary file (4.23 kB).
swift/llm/template/template/__pycache__/microsoft.cpython-310.pyc ADDED
Binary file (8.31 kB).
swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc ADDED
Binary file (8.18 kB).
swift/llm/template/template/__pycache__/minimax.cpython-310.pyc ADDED
Binary file (4.71 kB).
swift/llm/template/template/__pycache__/mistral.cpython-310.pyc ADDED
Binary file (2.67 kB).
swift/llm/template/template/__pycache__/molmo.cpython-310.pyc ADDED
Binary file (2.76 kB).
swift/llm/template/template/__pycache__/moonshot.cpython-310.pyc ADDED
Binary file (3.39 kB).
swift/llm/template/template/__pycache__/mplug.cpython-310.pyc ADDED
Binary file (8.46 kB).
swift/llm/template/template/__pycache__/openbuddy.cpython-310.pyc ADDED
Binary file (2.44 kB).
swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc ADDED
Binary file (2.3 kB).
swift/llm/template/template/__pycache__/qwen.cpython-310.pyc ADDED
Binary file (24.7 kB).
swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc ADDED
Binary file (6.57 kB).
swift/llm/template/template/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.88 kB).
swift/llm/template/template/__pycache__/valley.cpython-310.pyc ADDED
Binary file (6.31 kB).
swift/llm/template/template/__pycache__/yi.cpython-310.pyc ADDED
Binary file (2.91 kB).
swift/llm/template/template/deepseek.py ADDED
@@ -0,0 +1,315 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import os
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ from swift.utils import get_env_args
+ from ..base import Template
+ from ..constant import LLMTemplateType, MLLMTemplateType
+ from ..register import TemplateMeta, register_template
+ from ..template_inputs import StdTemplateInputs
+ from ..utils import Prompt, findall
+
+
+ @dataclass
+ class DeepseekTemplateMeta(TemplateMeta):
+     prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
+     prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:'])
+     chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
+     suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
+     system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n'])
+
+
+ register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek, ))
+
+ register_template(
+     TemplateMeta(
+         LLMTemplateType.deepseek_coder,
+         prefix=['{{SYSTEM}}'],
+         prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'],
+         chat_sep=['\n<|EOT|>\n'],
+         suffix=['\n<|EOT|>'],
+         stop_words=['<|EOT|>'],
+         default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, '
+                         'developed by Deepseek Company, and you only answer questions related to computer science. '
+                         'For politically sensitive questions, security and privacy issues, '
+                         'and other non-computer science questions, you will refuse to answer\n')))
+
+
+ class DeepseekVLTemplate(Template):
+     image_placeholder = ['<image_placeholder>']
+     skip_prompt = False
+     use_model = True
+     placeholder_tokens = ['<image_placeholder>']
+
+     image_token_num_per_image: int = 576
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+         is_janus = getattr(self, 'is_janus', False)
+
+         encoded = super()._encode(inputs)
+         images = inputs.images
+         processor = self.processor
+         input_ids, labels = encoded['input_ids'], encoded['labels']
+
+         if not inputs.generate_mode:  # understanding task
+             idx_list = findall(input_ids, processor.image_id)  # '<image_placeholder>'
+             new_input_ids, new_labels = [], []
+             lo = 0
+             for hi in idx_list:
+                 new_input_ids += input_ids[lo:hi]
+                 if labels is not None:
+                     new_labels += labels[lo:hi]
+                 image_tokens = [processor.image_id] * processor.num_image_tokens
+                 if is_janus:
+                     image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
+                 new_input_ids += image_tokens
+                 new_labels += [-100] * len(image_tokens)
+                 lo = hi + 1
+             new_input_ids += input_ids[lo:]
+             if labels is not None:
+                 new_labels += labels[lo:]
+             else:
+                 new_labels = None
+             if is_janus:
+                 from janus.models.processing_vlm import VLChatProcessorOutput
+             else:
+                 from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
+
+             images_outputs = processor.image_processor(images, return_tensors='pt')
+             output = VLChatProcessorOutput(
+                 sft_format=None,
+                 input_ids=torch.tensor(new_input_ids),
+                 pixel_values=images_outputs.pixel_values,
+                 num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
+             encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
+             return encoded
+
+         else:  # image generation task
+             if self.is_training:
+                 raise NotImplementedError('Only support the inference of generation of Janus series models.')
+             sft_format = self.tokenizer.decode(input_ids)
+             prompt = sft_format + processor.image_start_tag
+             input_ids = processor.tokenizer.encode(prompt)
+             input_ids = torch.LongTensor(input_ids)
+
+             encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
+             return encoded
+
+     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         if not inputs.get('generate_mode'):
+             inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype)
+             inputs_embeds = model.prepare_inputs_embeds(**inputs)
+             return {'inputs_embeds': inputs_embeds}
+         else:
+             return inputs
+
+     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+         gene_img_list = [b.get('generate_mode') for b in batch]
+         if all(gene_img_list):
+             generate_mode = True
+         elif not any(gene_img_list):
+             generate_mode = False
+         else:
+             raise NotImplementedError('Do not support understanding and image generation tasks in one batch.')
+
+         if not generate_mode:
+             output = self.fetch_inputs(batch, ['output'])['output']
+             batched_output = dict(self.processor.batchify(output))
+             res = super()._data_collator(batch, padding_to=padding_to)
+             return {**batched_output, **res}
+         else:
+             res = super()._data_collator(batch, padding_to=padding_to)
+             res['generate_mode'] = generate_mode
+             return res
+
+     def generate(self, model, *args, **kwargs):
+         if not kwargs.get('generate_mode'):
+             return super().generate(model, *args, **kwargs)
+
+         else:
+             # number of images to generate per prompt; named parallel_size in the author's code
+             parallel_size = kwargs['generation_config'].num_return_sequences
+             temperature = kwargs['generation_config'].temperature
+             cfg_weight = get_env_args('cfg_weight', float, 5.0)
+
+             input_ids = kwargs['input_ids']  # [bsz, max_input_token_num]
+             bsz, max_input_token_num = input_ids.shape
+             tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
+                                  dtype=torch.int).cuda()  # [bsz, parallel_size*2, max_input_token_num]
+             for i in range(parallel_size * 2):
+                 tokens[:, i, :] = input_ids
+                 if i % 2 != 0:
+                     tokens[:, i, 1:-1] = self.processor.pad_id
+
+             inputs_embeds = model.language_model.get_input_embeddings()(
+                 tokens)  # [bsz, parallel_size*2, max_input_token_num, 2048]
+
+             generated_tokens = torch.zeros(
+                 (bsz, parallel_size, self.image_token_num_per_image),
+                 dtype=torch.int).cuda()  # [bsz, parallel_size, image_token_num_per_image] placeholder for the generated tokens
+
+             # flatten the first two dimensions into a single batch dimension
+             inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
+             generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)
+
+             for i in range(self.image_token_num_per_image):  # generate the image tokens auto-regressively
+                 outputs = model.language_model.model(
+                     inputs_embeds=inputs_embeds,
+                     use_cache=True,
+                     past_key_values=outputs.past_key_values if i != 0 else None)
+                 hidden_states = outputs.last_hidden_state
+
+                 logits = self.model.gen_head(hidden_states[:, -1, :])
+                 logit_cond = logits[0::2, :]
+                 logit_uncond = logits[1::2, :]
+
+                 logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+                 probs = torch.softmax(logits / temperature, dim=-1)
+
+                 next_token = torch.multinomial(probs, num_samples=1)
+                 generated_tokens[:, i] = next_token.squeeze(dim=-1)  # [bsz * parallel_size, image_token_num_per_image]
+
+                 next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+                 img_embeds = model.prepare_gen_img_embeds(next_token)  # [bsz * parallel_size * 2, 2048]
+                 inputs_embeds = img_embeds.unsqueeze(dim=1)  # [bsz * parallel_size * 2, 1, 2048]
+
+             # no need to restore the original first two dimensions; left for a later update of the caller
+             # inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
+             # generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)
+
+             return {'sequences': generated_tokens}
+
+     def decode(self, generate_ids: List[int], **kwargs) -> Any:
+         if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
+             return super().decode(generate_ids, **kwargs)
+         else:
+             img_size = get_env_args('img_size', int, 384)
+             patch_size = 16
+
+             num_to_decode = 1  # for now, generate_ids is a 1D list
+
+             generate_ids = torch.tensor(generate_ids).unsqueeze(0)  # [num_to_decode=1, self.image_token_num_per_image]
+
+             dec = self.model.gen_vision_model.decode_code(
+                 generate_ids.to(dtype=torch.int),
+                 shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
+             dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)  # [num_to_decode, H, W, ch=3]
+
+             dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+             visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
+             visual_img[:, :, :] = dec
+
+             img_list = []
+             for i in range(num_to_decode):
+                 cur_img = Image.fromarray(visual_img[i])
+                 img_list.append({'type': 'image', 'image': cur_img})
+             return img_list
+
+
+ @dataclass
+ class DeepseekVLTemplateMeta(DeepseekTemplateMeta):
+     default_system: Optional[str] = ('You are a helpful language and vision assistant. '
+                                      'You are able to understand the visual content that the user provides, '
+                                      'and assist the user with a variety of tasks using natural language.')
+
+
+ register_template(DeepseekVLTemplateMeta(
+     MLLMTemplateType.deepseek_vl,
+     template_cls=DeepseekVLTemplate,
+ ))
+
+
+ class DeepseekJanus(DeepseekVLTemplate):
+     is_janus = True
+     image_placeholder = ['<image_placeholder>\n']
+
+
+ register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus))
+
+
+ @dataclass
+ class DeepseekV2_5TemplateMeta(TemplateMeta):
+     prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}'])
+     prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>'])
+     chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
+     suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
+
+
+ register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5))
+
+
+ class DeepseekR1Template(Template):
+
+     def _swift_encode(self, inputs: StdTemplateInputs):
+         if not self.is_training:
+             for message in inputs.messages:
+                 if message['role'] == 'assistant' and isinstance(message['content'], str):
+                     message['content'] = message['content'].split('</think>')[-1]
+         return super()._swift_encode(inputs)
+
+
+ register_template(
+     DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n'))
+
+
+ class DeepseekVL2Template(DeepseekVLTemplate):
+     image_placeholder = ['<image>\n']
+     placeholder_tokens = ['<image>']
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+         from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput
+         encoded = Template._encode(self, inputs)
+         images = inputs.images
+         processor = self.processor
+         input_ids, labels = encoded['input_ids'], encoded['labels']
+         images_seq_mask = [False] * len(input_ids)
+         idx_list = findall(input_ids, processor.image_token_id)  # '<image>'
+         _, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images(
+             '<image>' * len(images), images, cropping=len(images) <= 2)
+         new_num_tokens = 0
+         for idx, n_image_tokens in zip(idx_list, num_image_tokens):
+             image_tokens = [processor.image_token_id] * n_image_tokens
+             input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:]
+             if labels is not None:
+                 labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:]
+             images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:]
+             new_num_tokens += n_image_tokens - 1
+
+         output = VLChatProcessorOutput(
+             sft_format=None,
+             input_ids=torch.tensor(input_ids),
+             target_ids=torch.tensor(input_ids),
+             images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)),
+             images_seq_mask=torch.tensor(images_seq_mask),
+             images_spatial_crop=torch.tensor(images_spatial_crop),
+             num_image_tokens=num_image_tokens)
+         output.images = output.images.to(dtype=self.model_info.torch_dtype)
+         encoded = {'output': output, 'input_ids': input_ids, 'labels': labels}
+         return encoded
+
+     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool)
+         inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long)
+         inputs_embeds = model.prepare_inputs_embeds(**inputs)
+         return {'inputs_embeds': inputs_embeds}
+
+
+ register_template(
+     DeepseekV2_5TemplateMeta(
+         MLLMTemplateType.deepseek_vl2,
+         prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
+         template_cls=DeepseekVL2Template,
+     ))
+
+ register_template(
+     DeepseekVLTemplateMeta(
+         MLLMTemplateType.deepseek_janus_pro,
+         prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
+         template_cls=DeepseekJanus))
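Editor's note: a self-contained sketch of the classifier-free guidance step inside DeepseekVLTemplate.generate above. The input batch is doubled so even rows are conditional and odd rows unconditional, the two logit sets are blended with cfg_weight, and one image token is sampled per pair. Shapes and values below are made up for illustration; only the blending formula mirrors the template code.

import torch

cfg_weight = 5.0                  # default used by get_env_args('cfg_weight', float, 5.0) above
temperature = 1.0                 # placeholder; the real value comes from the generation config
parallel_size, vocab_size = 2, 16 # toy sizes

logits = torch.randn(parallel_size * 2, vocab_size)  # doubled batch: cond/uncond interleaved
logit_cond = logits[0::2, :]      # rows 0, 2, ... -> conditional
logit_uncond = logits[1::2, :]    # rows 1, 3, ... -> unconditional

guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # one image token per sample
print(next_token.shape)  # torch.Size([2, 1])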