zhiyuan8 committed · verified
Commit 919416f · 1 Parent(s): ff5066e

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/arch.png filter=lfs diff=lfs merge=lfs -text
37
+ figures/perf_speed.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,108 @@
1
- ---
2
- license: mit
3
- ---
1
+ ---
2
+ license: mit
3
+ ---
4
+ <div align="center">
5
+ <a href="https://github.com/MoonshotAI/Kimi-Linear/blob/master/tech_report.pdf"><img width="80%" src="figures/banner.png"></a>
6
+ </div>
7
+
8
+ <div align="center">
9
+ <a href="https://github.com/MoonshotAI/Kimi-Linear/blob/master/tech_report.pdf"><img src="figures/logo.png" height="16" width="16" style="vertical-align:middle"><b> Tech Report</b></a> |
10
+ <a href="https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct"><img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" height="16" width="16" style="vertical-align:middle"><b> HuggingFace</b></a>
11
+ </div>
12
+
13
+
14
+ <div align="center">
15
+ <img width="90%" src="figures/perf_speed.png">
16
+ <p><em><b>(a)</b> On MMLU-Pro (4k context length), Kimi Linear scores 51.0 with speed comparable to full attention. On RULER (128k context length), it is Pareto-optimal, scoring 84.3 with a 3.98x speedup. <b>(b)</b> Kimi Linear achieves 6.3x faster TPOT than MLA, offering significant speedups at long sequence lengths (1M tokens).</em></p>
17
+ </div>
18
+
19
+ ## Overview
20
+
21
+ Kimi Linear is a hybrid linear attention architecture that outperforms traditional full attention across short-context, long-context, and reinforcement learning (RL) scaling regimes.
22
+ At its core is Kimi Delta Attention (KDA)—a refined version of [Gated DeltaNet](https://arxiv.org/abs/2412.06464) that introduces a more efficient gating mechanism to optimize the use of finite-state RNN memory.
23
+
24
+ Kimi Linear achieves superior performance and hardware efficiency, especially for long-context tasks. It reduces the need for large KV caches by up to 75% and boosts decoding throughput by up to $6\times$ for contexts as long as 1M tokens.
25
+
26
+ We open-source the KDA kernel in [FLA](https://github.com/fla-org/flash-linear-attention/tree/main/fla/ops/kda), and release two model checkpoints trained on 5.7T tokens.
27
+
28
+
29
+ | **Model** | **#Total Params** | **#Activated Params** | **Context Length** | **Download Link** |
30
+ | :------------------: | :---------------: | :-------------------: | :----------------: | :------------------------------------------------------------------------------: |
31
+ | Kimi-Linear-Base | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Base) |
32
+ | Kimi-Linear-Instruct | 48B | 3B | 1M | [🤗 Hugging Face](https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct) |
33
+
34
+ ## Key Features
35
+
36
+ - **Kimi Delta Attention (KDA):** A linear attention mechanism that refines the gated delta rule with fine-grained gating.
37
+ - **Hybrid Architecture:** A 3:1 KDA-to-global MLA ratio reduces memory usage while maintaining or surpassing the quality of full attention (see the layer-layout sketch after the figure below).
38
+ - **Superior Performance:** Outperforms full attention on a variety of tasks, including long-context and RL-style benchmarks, in fair comparisons over 1.4T-token training runs.
39
+ - **High Throughput:** Achieves up to $6\times$ faster decoding and significantly reduces time per output token (TPOT).
40
+
41
+ <div align="center">
42
+ <img width="60%" src="figures/arch.png">
43
+ </div>
44
+
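+ To make the 3:1 ratio concrete, here is a small sketch (not part of the model code) that reconstructs the per-layer attention type from the layer lists in the released `config.json`:
+
+ ```py
+ # 1-indexed ids of full-attention (MLA) layers, copied from
+ # linear_attn_config.full_attn_layers in config.json; all other layers use KDA.
+ full_attn_layers = {4, 8, 12, 16, 20, 24, 27}
+ num_hidden_layers = 27
+
+ layout = ["MLA" if (i + 1) in full_attn_layers else "KDA" for i in range(num_hidden_layers)]
+ print(layout)  # ['KDA', 'KDA', 'KDA', 'MLA', ...] -- roughly three KDA layers per MLA layer
+ ```
+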
45
+ ## Usage
46
+
47
+ ### Inference with Hugging Face Transformers
48
+
49
+ To use the Kimi Linear model, we recommend the following environment:
50
+
51
+ * `python` >= 3.10
52
+ * `torch` >= 2.6
53
+ * `fla-core` >= 0.4.0
54
+
55
+ ```shell
56
+ pip install -U fla-core
57
+ ```
58
+
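+ If you want to confirm the KDA kernels are available after installation, the imports below are the same ones `modeling_kimi.py` in this repository uses (a minimal availability check, nothing more):
+
+ ```py
+ # Provided by fla-core; these imports mirror modeling_kimi.py.
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda  # chunked / recurrent KDA kernels
+ from fla.ops.kda.gate import fused_kda_gate              # fine-grained KDA gating
+
+ print(chunk_kda.__name__, fused_recurrent_kda.__name__, fused_kda_gate.__name__)
+ ```
+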
59
+ Example Code:
60
+ ```py
61
+ from transformers import AutoModelForCausalLM, AutoTokenizer
62
+
63
+ model_name = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
64
+ model = AutoModelForCausalLM.from_pretrained(
65
+ model_name,
66
+ torch_dtype="auto",
67
+ device_map="auto",
68
+ trust_remote_code=True
69
+ )
70
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
71
+
72
+ messages = [
73
+ {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
74
+ {"role": "user", "content": "Is 123 a prime?"}
75
+ ]
76
+ input_ids = tokenizer.apply_chat_template(
77
+ messages,
78
+ add_generation_prompt=True,
79
+ return_tensors="pt"
80
+ ).to(model.device)
81
+ generated_ids = model.generate(inputs=input_ids, max_new_tokens=500)
82
+ response = tokenizer.batch_decode(generated_ids)[0]
83
+ print(response)
84
+ ```
85
+
86
+ ### Deployment
87
+
88
+ For deployment, you can use the latest vLLM to create an OpenAI-compatible API endpoint.
89
+
90
+ ```sh
91
+ vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct \
92
+ --port 8000 \
93
+ --tensor-parallel-size 4 \
94
+ --max-model-len 1048576 \
95
+ --trust-remote-code
96
+ ```
97
+
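+ Once the server is running, any OpenAI-compatible client can query it. Below is a minimal sketch using the `openai` Python package; the base URL and placeholder API key are assumptions for a default local deployment:
+
+ ```py
+ from openai import OpenAI
+
+ # vLLM exposes an OpenAI-compatible API; the key is unused for a local server.
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+ response = client.chat.completions.create(
+     model="moonshotai/Kimi-Linear-48B-A3B-Instruct",
+     messages=[{"role": "user", "content": "Is 123 a prime?"}],
+     max_tokens=256,
+ )
+ print(response.choices[0].message.content)
+ ```
+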
98
+ ### Citation
99
+
100
+ If you find our work useful, please cite:
101
+ ```bibtex
102
+ @article{kimi2025kda,
103
+ title = {Kimi Linear: An Expressive, Efficient Attention Architecture},
104
+ author = {Kimi Team},
105
+ year = {2025},
106
+ url = {https://github.com/MoonshotAI/Kimi-Linear/blob/master/tech_report.pdf}
107
+ }
108
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,48 @@
1
+ {% macro render_content(msg) -%}
2
+ {%- set c = msg.get('content') -%}
3
+ {%- if c is string -%}
4
+ {{ c }}
5
+ {%- elif c is not none -%}
6
+ {% for content in c -%}
7
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
8
+ <|media_start|>image<|media_content|><|media_pad|><|media_end|>
9
+ {% else -%}
10
+ {{ content['text'] }}
11
+ {%- endif -%}
12
+ {%- endfor -%}
13
+ {%- endif -%}
14
+ {%- endmacro %}
15
+
16
+
17
+ {%- if tools -%}
18
+ <|im_system|>tool_declare<|im_middle|>{{ tools | tojson(separators=(',', ':')) }}<|im_end|>
19
+ {%- endif -%}
20
+ {% for message in messages %}
21
+ {%- set role_name = message.get('name') or message['role'] -%}
22
+ {%- if message['role'] == 'user' -%}
23
+ <|im_user|>{{role_name}}<|im_middle|>
24
+ {%- elif message['role'] == 'assistant' -%}
25
+ <|im_assistant|>{{role_name}}<|im_middle|>
26
+ {%- else -%}
27
+ <|im_system|>{{role_name}}<|im_middle|>
28
+ {%- endif -%}
29
+
30
+ {%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
31
+ {{render_content(message)}}<|tool_calls_section_begin|>
32
+ {%- for tool_call in message['tool_calls'] -%}
33
+ {%- set formatted_id = tool_call['id'] -%}
34
+ <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
35
+ {%- endfor -%}
36
+ <|tool_calls_section_end|>
37
+ {%- elif message['role'] == 'tool' -%}
38
+ {%- set tool_call_id = message.tool_call_id -%}
39
+ ## Return of {{ tool_call_id }}
40
+ {{render_content(message)}}
41
+ {%- elif message['content'] is not none -%}
42
+ {{render_content(message)}}
43
+ {%- endif -%}
44
+ <|im_end|>
45
+ {%- endfor -%}
46
+ {%- if add_generation_prompt -%}
47
+ <|im_assistant|>assistant<|im_middle|>
48
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,86 @@
1
+ {
2
+ "architectures": [
3
+ "KimiLinearForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_kimi.KimiLinearConfig",
7
+ "AutoModel": "modeling_kimi.KimiLinearModel",
8
+ "AutoModelForCausalLM": "modeling_kimi.KimiLinearForCausalLM"
9
+ },
10
+ "bos_token_id": 163584,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 163586,
13
+ "first_k_dense_replace": 1,
14
+ "head_dim": 72,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2304,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 9216,
19
+ "kv_lora_rank": 512,
20
+ "linear_attn_config": {
21
+ "full_attn_layers": [
22
+ 4,
23
+ 8,
24
+ 12,
25
+ 16,
26
+ 20,
27
+ 24,
28
+ 27
29
+ ],
30
+ "head_dim": 128,
31
+ "kda_layers": [
32
+ 1,
33
+ 2,
34
+ 3,
35
+ 5,
36
+ 6,
37
+ 7,
38
+ 9,
39
+ 10,
40
+ 11,
41
+ 13,
42
+ 14,
43
+ 15,
44
+ 17,
45
+ 18,
46
+ 19,
47
+ 21,
48
+ 22,
49
+ 23,
50
+ 25,
51
+ 26
52
+ ],
53
+ "num_heads": 32,
54
+ "short_conv_kernel_size": 4
55
+ },
56
+ "mla_use_nope": true,
57
+ "model_max_length": 1048576,
58
+ "model_type": "kimi_linear",
59
+ "moe_intermediate_size": 1024,
60
+ "moe_layer_freq": 1,
61
+ "moe_renormalize": true,
62
+ "moe_router_activation_func": "sigmoid",
63
+ "num_attention_heads": 32,
64
+ "num_expert_group": 1,
65
+ "num_experts": 256,
66
+ "num_experts_per_token": 8,
67
+ "num_hidden_layers": 27,
68
+ "num_key_value_heads": 32,
69
+ "num_nextn_predict_layers": 0,
70
+ "num_shared_experts": 1,
71
+ "pad_token_id": 163839,
72
+ "q_lora_rank": null,
73
+ "qk_nope_head_dim": 128,
74
+ "qk_rope_head_dim": 64,
75
+ "rms_norm_eps": 1e-05,
76
+ "rope_scaling": null,
77
+ "rope_theta": 10000.0,
78
+ "routed_scaling_factor": 2.446,
79
+ "tie_word_embeddings": false,
80
+ "topk_group": 1,
81
+ "transformers_version": "4.57.1",
82
+ "use_cache": true,
83
+ "use_grouped_topk": true,
84
+ "v_head_dim": 128,
85
+ "vocab_size": 163840
86
+ }
configuration_kimi.py ADDED
@@ -0,0 +1,140 @@
1
+ # coding=utf-8
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class KimiLinearConfig(PretrainedConfig):
8
+ model_type = "kimi_linear"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ model_type="kimi_linear",
14
+ vocab_size=163840,
15
+ hidden_size=4096,
16
+ head_dim=None,
17
+ intermediate_size=11008,
18
+ num_hidden_layers=32,
19
+ num_attention_heads=32,
20
+ num_key_value_heads=None,
21
+ hidden_act="silu",
22
+ initializer_range=0.02,
23
+ rms_norm_eps=1e-6,
24
+ use_cache=True,
25
+ pad_token_id=0,
26
+ bos_token_id=1,
27
+ eos_token_id=2,
28
+ rope_theta=10000.0,
29
+ rope_scaling=None,
30
+ tie_word_embeddings=False,
31
+ moe_intermediate_size: Optional[int] = None,
32
+ moe_renormalize: bool = True,
33
+ moe_router_activation_func: str = "sigmoid",
34
+ num_experts: Optional[int] = None,
35
+ num_experts_per_token: Optional[int] = None,
36
+ num_shared_experts: int = 0,
37
+ routed_scaling_factor: float = 1.0,
38
+ first_k_dense_replace: int = 0,
39
+ moe_layer_freq: int = 1,
40
+ use_grouped_topk: bool = True,
41
+ num_expert_group: int = 1,
42
+ topk_group: int = 1,
43
+ q_lora_rank: Optional[int] = None,
44
+ kv_lora_rank: Optional[int] = None,
45
+ qk_nope_head_dim: Optional[int] = None,
46
+ qk_rope_head_dim: Optional[int] = None,
47
+ v_head_dim: Optional[int] = None,
48
+ mla_use_nope: Optional[bool] = False,
49
+ num_nextn_predict_layers: int = 0,
50
+ linear_attn_config: Optional[dict] = None,
51
+ **kwargs,
52
+ ):
53
+ self.model_type = model_type
54
+ self.vocab_size = vocab_size
55
+ self.hidden_size = hidden_size
56
+ self.head_dim = (
57
+ head_dim if head_dim is not None else hidden_size // num_attention_heads
58
+ )
59
+ self.intermediate_size = intermediate_size
60
+ self.num_hidden_layers = num_hidden_layers
61
+ self.num_attention_heads = num_attention_heads
62
+
63
+ # for backward compatibility
64
+ if num_key_value_heads is None:
65
+ num_key_value_heads = num_attention_heads
66
+
67
+ self.num_key_value_heads = num_key_value_heads
68
+ self.hidden_act = hidden_act
69
+ self.initializer_range = initializer_range
70
+ self.rms_norm_eps = rms_norm_eps
71
+ self.use_cache = use_cache
72
+ self.rope_theta = rope_theta
73
+ self.rope_scaling = rope_scaling
74
+
75
+ self.q_lora_rank = q_lora_rank
76
+ self.kv_lora_rank = kv_lora_rank
77
+ self.qk_nope_head_dim = qk_nope_head_dim
78
+ self.qk_rope_head_dim = qk_rope_head_dim
79
+ self.v_head_dim = v_head_dim
80
+ self.mla_use_nope = mla_use_nope
81
+ # moe config
82
+ self.num_experts = num_experts
83
+ self.num_experts_per_token = num_experts_per_token
84
+ self.moe_renormalize = moe_renormalize
85
+ self.num_shared_experts = num_shared_experts
86
+ self.routed_scaling_factor = routed_scaling_factor
87
+ self.moe_router_activation_func = moe_router_activation_func
88
+ assert self.moe_router_activation_func in ("softmax", "sigmoid")
89
+ self.moe_intermediate_size = moe_intermediate_size
90
+ self.first_k_dense_replace = first_k_dense_replace
91
+ self.moe_layer_freq = moe_layer_freq
92
+ self.use_grouped_topk = use_grouped_topk
93
+ self.num_expert_group = num_expert_group
94
+ self.topk_group = topk_group
95
+ self.num_nextn_predict_layers = num_nextn_predict_layers
96
+
97
+ if linear_attn_config is not None:
98
+ assert linear_attn_config["kda_layers"] is not None
99
+ assert linear_attn_config["full_attn_layers"] is not None
100
+ self.linear_attn_config = linear_attn_config
101
+
102
+ super().__init__(
103
+ pad_token_id=pad_token_id,
104
+ bos_token_id=bos_token_id,
105
+ eos_token_id=eos_token_id,
106
+ tie_word_embeddings=tie_word_embeddings,
107
+ **kwargs,
108
+ )
109
+
110
+ @property
111
+ def is_mla(self):
112
+ return (
113
+ self.q_lora_rank is not None
114
+ or self.kv_lora_rank is not None
115
+ or self.qk_nope_head_dim is not None
116
+ or self.qk_rope_head_dim is not None
117
+ or self.v_head_dim is not None
118
+ or self.mla_use_nope is True
119
+ )
120
+
121
+ @property
122
+ def is_moe(self):
123
+ return self.num_experts is not None
124
+
125
+ @property
126
+ def is_linear_attn(self) -> bool:
127
+ return not (
128
+ self.linear_attn_config is None
129
+ or (
130
+ isinstance(self.linear_attn_config, dict)
131
+ and self.linear_attn_config["kda_layers"] is not None
132
+ and len(self.linear_attn_config["kda_layers"]) == 0
133
+ )
134
+ )
135
+
136
+ def is_kda_layer(self, layer_idx: int):
137
+ return (
138
+ self.linear_attn_config is not None
139
+ and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
140
+ )
figures/arch.png ADDED

Git LFS Details

  • SHA256: 132ae021fa4661ed39e7be784d46f05f22b82aabb9afd2bab8dbdc0a5a61cba0
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
figures/banner.png ADDED
figures/logo.png ADDED
figures/perf_speed.png ADDED

Git LFS Details

  • SHA256: f8951e618db41ae57fa0cec4845d7b275dffbd7f9db12c6496bfea536c625aea
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 163584,
4
+ "eos_token_id": 163586,
5
+ "pad_token_id": 163839,
6
+ "transformers_version": "4.57.1"
7
+ }
model-00001-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c5c908aa3b86b6486080b577cb7aa8dbe9ca7cb18789653768017e602b61a7f
3
+ size 4999482712
model-00002-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fcb34e9ebe2434f32761c06ef17a465157308e6e583eb7eb70cc25e57cd2cb0
3
+ size 4999923264
model-00003-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f35d2a95dd1e3170fd642d0db4d0d07933985ef59041494652092cc27893e231
3
+ size 4997138040
model-00004-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eda18b6226777bb9a07584dfa64986ac4f28a26cee3203f16ffb14deef9ef48b
3
+ size 4997148016
model-00005-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:652e8a43d493105176807d256af0a5c56e45c6d783e6c8221832918f3425c0a0
3
+ size 4999923296
model-00006-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77f5dc551436c934f0991eee0c319e0f33689ea3c10cb8cb8f48acc32238526f
3
+ size 4997138040
model-00007-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3980c7efccdc6a27633eb8909afc381cb780cccaa7be0347d1645496ea3eb5a2
3
+ size 4997148128
model-00008-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dab91b8eaed9874c75de99a4a08669f520fa3e2c8977175333db552504a1c5d3
3
+ size 4999924384
model-00009-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13f6ae84d557682ec4a0fc8b6090d4f89cdd26e5e216445cc9d77a65c7f4c90b
3
+ size 4997139104
model-00010-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d58ab0a201e26ff429b9d18678a76f3a3284ad977719e055f94d892133ee247b
3
+ size 4997149016
model-00011-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:142aa90317af104b2d9f5a6ae4dc661f4a7f7c152f83d6c2477de8037be92201
3
+ size 4999924408
model-00012-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb1d4fe2d94a04898eb8300b5144923e5540a75091ee5f4c8b67936a69d91780
3
+ size 4997139104
model-00013-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12fb6f2dea889d460f33f7fbb55f76d7beb468698cf56707f4d77a9ab69461d3
3
+ size 4997148992
model-00014-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf178169e0dfdc721492f1a98a5be2ef5f66fd8569039f5a77819641a5a1b32d
3
+ size 4999924440
model-00015-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e65460934e8794faeadd6d8cbeffd23fcdbf07d9c61ca92ef97afc95d0ccdaa
3
+ size 4997139104
model-00016-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05cc77846f94d50dc09180f5844f01aee38e489b1fd833d8c7aec6a62214ef03
3
+ size 4997148960
model-00017-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaf238a2f1a971ef311c445309de323992da59287d55759cf2a4a3a85ca6a1cc
3
+ size 4999924472
model-00018-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:315cdb6964a975522cdb755cf5eb76b46478346b015113e241f81127ad9e6fd4
3
+ size 4997139104
model-00019-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbdc2c77e41baa76a2c2b3ced0a59fe7587e95ca3d1acc75247b88a80dee3041
3
+ size 4999934384
model-00020-of-00020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f1e4a9194d045e01c90ed2697939bcedd533b6aa1f1b97b0ae0a5932e5a4bc7
3
+ size 3280687152
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_kimi.py ADDED
@@ -0,0 +1,1028 @@
1
+ import math
2
+ from collections.abc import Callable
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from einops import rearrange
9
+ from packaging import version
10
+ from torch import nn
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.masking_utils import create_causal_mask
15
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast)
18
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
19
+ PreTrainedModel)
20
+ from transformers.processing_utils import Unpack
21
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
22
+ from transformers.utils import (TransformersKwargs, auto_docstring,
23
+ can_return_tuple, logging)
24
+ from transformers.utils.generic import OutputRecorder, check_model_inputs
25
+
26
+ try:
27
+ from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
28
+ from fla.modules import FusedRMSNormGated, ShortConvolution
29
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda
30
+ from fla.ops.kda.gate import fused_kda_gate
31
+ except ImportError:
32
+ raise ImportError("Plese run `pip install -U fla-core`")
33
+
34
+ from .configuration_kimi import KimiLinearConfig
35
+
36
+ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
37
+ "Please upgrade transformers to >= 4.56.0"
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class KimiDynamicCache:
43
+ """
44
+ Dynamic cache for Kimi model.
45
+ Inspired by Qwen3-Next
46
+ """
47
+ is_compileable = False
48
+
49
+ def __init__(self, config: KimiLinearConfig):
50
+ super().__init__()
51
+ self.config = config
52
+
53
+ if config.linear_attn_config is not None:
54
+ self.layer_types = []
55
+ for i in range(config.num_hidden_layers):
56
+ if config.is_kda_layer(i):
57
+ self.layer_types.append("linear_attention")
58
+ else:
59
+ self.layer_types.append("full_attention")
60
+ else:
61
+ self.layer_types = ["full_attention"] * config.num_hidden_layers
62
+
63
+ self.transformer_layers = [
64
+ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
65
+ ]
66
+
67
+ linear_layers = [i for i in range(
68
+ config.num_hidden_layers) if self.layer_types[i] == "linear_attention"]
69
+ self.last_linear_layer = linear_layers[-1] if linear_layers else -1
70
+
71
+ self.conv_states = [None for _ in range(config.num_hidden_layers)]
72
+ self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
73
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
74
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
75
+
76
+ def __len__(self):
77
+ return len(self.layer_types)
78
+
79
+ def update(
80
+ self,
81
+ key_states: torch.Tensor,
82
+ value_states: torch.Tensor,
83
+ layer_idx: int,
84
+ cache_kwargs: Optional[dict[str, Any]] = None,
85
+ ) -> tuple[torch.Tensor, torch.Tensor]:
86
+ if self.key_cache[layer_idx] is None:
87
+ self.key_cache[layer_idx] = key_states
88
+ self.value_cache[layer_idx] = value_states
89
+ else:
90
+ self.key_cache[layer_idx] = torch.cat(
91
+ [self.key_cache[layer_idx], key_states], dim=2)
92
+ self.value_cache[layer_idx] = torch.cat(
93
+ [self.value_cache[layer_idx], value_states], dim=2)
94
+
95
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
96
+
97
+ def reorder_cache(self, beam_idx: torch.LongTensor):
98
+ """Reorders the cache for beam search, given the selected beam indices."""
99
+ for layer_idx in range(len(self.key_cache)):
100
+ if self.key_cache[layer_idx] is not None:
101
+ device = self.key_cache[layer_idx].device
102
+ beam_idx = beam_idx.to(device)
103
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
104
+ 0, beam_idx)
105
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
106
+ 0, beam_idx)
107
+
108
+ if self.conv_states[layer_idx] is not None:
109
+ device = self.conv_states[layer_idx][0].device
110
+ beam_idx = beam_idx.to(device)
111
+ q_conv, k_conv, v_conv = self.conv_states[layer_idx]
112
+ self.conv_states[layer_idx] = (
113
+ q_conv.index_select(0, beam_idx),
114
+ k_conv.index_select(0, beam_idx),
115
+ v_conv.index_select(0, beam_idx)
116
+ )
117
+ self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
118
+ 0, beam_idx)
119
+
120
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
121
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
122
+ # take any layer that contains cache and not empty tensor
123
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
124
+ if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
125
+ return 0
126
+ return self.key_cache[layer_idx].shape[-2]
127
+
128
+ def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
129
+ """
130
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
131
+ the given layer at `layer_idx`.
132
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
133
+ """
134
+ kv_offset = 0
135
+ query_length = cache_position.shape[0]
136
+ past_seen_tokens = self.get_seq_length(layer_idx)
137
+ kv_length = query_length + past_seen_tokens
138
+ return kv_length, kv_offset
139
+
140
+ @property
141
+ def has_previous_state(self):
142
+ """We have a previous state if the last linear (conv) layer was already updated."""
143
+ if self.last_linear_layer == -1:
144
+ return False
145
+ return self.conv_states[self.last_linear_layer] is not None
146
+
147
+
148
+ class KimiRMSNorm(nn.Module):
149
+ def __init__(self, hidden_size, eps=1e-6):
150
+ """
151
+ KimiRMSNorm is equivalent to T5LayerNorm
152
+ """
153
+ super().__init__()
154
+ self.weight = nn.Parameter(torch.ones(hidden_size))
155
+ self.variance_epsilon = eps
156
+
157
+ def forward(self, hidden_states):
158
+ input_dtype = hidden_states.dtype
159
+ hidden_states = hidden_states.to(torch.float32)
160
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
161
+ hidden_states = hidden_states * \
162
+ torch.rsqrt(variance + self.variance_epsilon)
163
+ return self.weight * hidden_states.to(input_dtype)
164
+
165
+
166
+ ALL_LAYERNORM_LAYERS.append(KimiRMSNorm)
167
+
168
+
169
+ class KimiBlockSparseMLP(nn.Module):
170
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
171
+ super().__init__()
172
+ self.config = config
173
+ self.ffn_dim = config.intermediate_size if intermediate_size is None else intermediate_size
174
+ self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size
175
+
176
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate
177
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down
178
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up
179
+
180
+ self.act_fn = ACT2FN[config.hidden_act]
181
+
182
+ def forward(self, hidden_states):
183
+ current_hidden_states = self.act_fn(
184
+ self.w1(hidden_states)) * self.w3(hidden_states)
185
+ current_hidden_states = self.w2(current_hidden_states)
186
+ return current_hidden_states
187
+
188
+
189
+ class KimiMLP(nn.Module):
190
+ def __init__(self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None):
191
+ super().__init__()
192
+ self.config = config
193
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
194
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
195
+ self.gate_proj = nn.Linear(
196
+ self.hidden_size, self.intermediate_size, bias=False)
197
+ self.up_proj = nn.Linear(
198
+ self.hidden_size, self.intermediate_size, bias=False)
199
+ self.down_proj = nn.Linear(
200
+ self.intermediate_size, self.hidden_size, bias=False)
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(self, x):
204
+ down_proj = self.down_proj(self.act_fn(
205
+ self.gate_proj(x)) * self.up_proj(x))
206
+ return down_proj
207
+
208
+
209
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
210
+ """
211
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
212
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
213
+ """
214
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
215
+ if n_rep == 1:
216
+ return hidden_states
217
+ hidden_states = hidden_states[:, :, None, :, :].expand(
218
+ batch, num_key_value_heads, n_rep, slen, head_dim)
219
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
220
+
221
+
222
+ def eager_attention_forward(
223
+ module: nn.Module,
224
+ query: torch.Tensor,
225
+ key: torch.Tensor,
226
+ value: torch.Tensor,
227
+ attention_mask: Optional[torch.Tensor],
228
+ scaling: float,
229
+ dropout: float = 0.0,
230
+ **kwargs: Unpack[TransformersKwargs],
231
+ ):
232
+ key_states = repeat_kv(key, module.num_key_value_groups)
233
+ value_states = repeat_kv(value, module.num_key_value_groups)
234
+
235
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
236
+ if attention_mask is not None:
237
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
238
+ attn_weights = attn_weights + causal_mask
239
+
240
+ attn_weights = nn.functional.softmax(
241
+ attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
242
+ attn_weights = nn.functional.dropout(
243
+ attn_weights, p=dropout, training=module.training)
244
+ attn_output = torch.matmul(attn_weights, value_states)
245
+ attn_output = attn_output.transpose(1, 2).contiguous()
246
+
247
+ return attn_output, attn_weights
248
+
249
+
250
+ class KimiMLAAttention(nn.Module):
251
+ """
252
+ Multi-Latent Attention adapted from deepseek-v3
253
+ """
254
+
255
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
256
+ nn.Module.__init__(self)
257
+ self.config = config
258
+ self.layer_idx = layer_idx
259
+ self.hidden_size = config.hidden_size
260
+ self.num_heads = config.num_attention_heads
261
+ self.num_key_value_heads = config.num_key_value_heads
262
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
263
+
264
+ self.rope_theta = config.rope_theta
265
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
266
+
267
+ try:
268
+ self.q_lora_rank = config.q_lora_rank
269
+ self.qk_rope_head_dim = config.qk_rope_head_dim
270
+ self.kv_lora_rank = config.kv_lora_rank
271
+ self.v_head_dim = config.v_head_dim
272
+ self.qk_nope_head_dim = config.qk_nope_head_dim
273
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
274
+ self.use_nope = config.mla_use_nope
275
+ self.scaling = self.q_head_dim ** (-0.5)
276
+ except Exception as e:
277
+ raise ValueError(
278
+ f"Kimi MLA config is not found or not properly formatted: {e}")
279
+
280
+ assert self.q_lora_rank is None
281
+ self.q_proj = nn.Linear(
282
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False,
283
+ )
284
+ self.kv_a_proj_with_mqa = nn.Linear(
285
+ self.hidden_size,
286
+ self.kv_lora_rank + self.qk_rope_head_dim,
287
+ bias=False,
288
+ )
289
+ self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank)
290
+ self.kv_b_proj = nn.Linear(
291
+ self.kv_lora_rank,
292
+ self.num_heads
293
+ * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
294
+ bias=False,
295
+ )
296
+ self.o_proj = nn.Linear(
297
+ self.num_heads * self.v_head_dim,
298
+ self.hidden_size,
299
+ bias=False,
300
+ )
301
+ self.is_causal = True
302
+ assert self.use_nope
303
+
304
+ def forward(
305
+ self,
306
+ hidden_states: torch.Tensor,
307
+ attention_mask: Optional[torch.Tensor] = None,
308
+ past_key_values: Optional[Cache] = None,
309
+ **kwargs,
310
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ batch_size, seq_length = hidden_states.shape[:-1]
312
+ query_shape = (batch_size, seq_length, -1, self.q_head_dim)
313
+ key_shape = (batch_size, seq_length, -1,
314
+ self.qk_nope_head_dim + self.v_head_dim)
315
+
316
+ q_states = self.q_proj(hidden_states)
317
+ q_states = q_states.view(query_shape).transpose(1, 2)
318
+ q_pass, q_rot = torch.split(
319
+ q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
320
+
321
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
322
+ k_pass, k_rot = torch.split(
323
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
324
+
325
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(
326
+ k_pass)).view(key_shape).transpose(1, 2)
327
+ k_pass, value_states = torch.split(
328
+ k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
329
+
330
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
331
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
332
+
333
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
334
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
335
+
336
+ if past_key_values is not None:
337
+ key_states, value_states = past_key_values.update(
338
+ key_states, value_states, self.layer_idx)
339
+
340
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
341
+ value_states = F.pad(
342
+ value_states, [0, self.q_head_dim - self.v_head_dim])
343
+
344
+ attention_interface: Callable = eager_attention_forward
345
+ if self.config._attn_implementation != "eager":
346
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
347
+
348
+ attn_output, _ = attention_interface(
349
+ self,
350
+ query_states,
351
+ key_states,
352
+ value_states,
353
+ attention_mask,
354
+ dropout=0.0 if not self.training else self.attention_dropout,
355
+ scaling=self.scaling,
356
+ **kwargs,
357
+ )
358
+
359
+ if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
360
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
361
+
362
+ attn_output = attn_output.reshape(
363
+ batch_size, seq_length, -1).contiguous()
364
+ attn_output = self.o_proj(attn_output)
365
+ return attn_output
366
+
367
+
368
+ class KimiDeltaAttention(nn.Module):
369
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
370
+ super().__init__()
371
+ self.config = config
372
+ self.mode = "chunk"
373
+
374
+ self.hidden_size = config.hidden_size
375
+ self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
376
+ self.head_dim = config.linear_attn_config["head_dim"]
377
+ self.num_heads = config.linear_attn_config["num_heads"]
378
+ self.head_k_dim = self.head_dim
379
+ self.num_k_heads = self.num_heads
380
+
381
+ self.layer_idx = layer_idx
382
+
383
+ assert self.mode in [
384
+ 'chunk', 'fused_recurrent'], f"Unsupported mode `{self.mode}`."
385
+
386
+ projection_k_size = self.head_k_dim * self.num_k_heads
387
+ projection_size = self.head_dim * self.num_heads
388
+
389
+ self.q_proj = nn.Linear(
390
+ self.hidden_size, projection_k_size, bias=False)
391
+ self.k_proj = nn.Linear(
392
+ self.hidden_size, projection_k_size, bias=False)
393
+ self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
394
+
395
+ self.q_conv1d = ShortConvolution(
396
+ hidden_size=projection_k_size,
397
+ kernel_size=self.conv_size,
398
+ activation='silu',
399
+ )
400
+ self.k_conv1d = ShortConvolution(
401
+ hidden_size=projection_k_size,
402
+ kernel_size=self.conv_size,
403
+ activation='silu'
404
+ )
405
+ self.v_conv1d = ShortConvolution(
406
+ hidden_size=projection_size,
407
+ kernel_size=self.conv_size,
408
+ activation='silu'
409
+ )
410
+
411
+ self.A_log = torch.nn.Parameter(torch.log(torch.empty(
412
+ self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
413
+
414
+ self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
415
+ self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
416
+
417
+ self.dt_bias = nn.Parameter(
418
+ torch.empty(projection_size, dtype=torch.float32))
419
+
420
+ self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
421
+
422
+ self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
423
+ self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
424
+
425
+ self.o_norm = FusedRMSNormGated(
426
+ self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
427
+ self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.Tensor] = None,
433
+ cache_params: Optional[KimiDynamicCache] = None,
434
+ **kwargs: Unpack[dict]
435
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
436
+ if attention_mask is not None:
437
+ if attention_mask.dim() != 2:
438
+ attention_mask = kwargs.get("padding_mask", None)
439
+
440
+ if attention_mask is not None and attention_mask.dim() != 2:
441
+ raise ValueError(
442
+ "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
443
+ "(0 = padding). 3D masks are not supported here."
444
+ )
445
+ use_cache = cache_params is not None
446
+ batch_size, q_len, _ = hidden_states.shape
447
+ mode = 'fused_recurrent' if q_len <= 64 else self.mode
448
+ if self.training:
449
+ assert mode == 'chunk', "Only chunk mode is supported in training."
450
+
451
+ cu_seqlens = kwargs.get('cu_seqlens', None)
452
+ indices = None
453
+ if attention_mask is not None:
454
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
455
+ hidden_states = index_first_axis(
456
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
457
+
458
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
459
+ recurrent_state = None
460
+ if cache_params is not None:
461
+ if cache_params.conv_states[self.layer_idx] is not None:
462
+ conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[
463
+ self.layer_idx]
464
+ recurrent_state = cache_params.recurrent_states[self.layer_idx]
465
+ q, conv_state_q = self.q_conv1d(
466
+ x=self.q_proj(hidden_states),
467
+ cache=conv_state_q,
468
+ output_final_state=use_cache,
469
+ cu_seqlens=cu_seqlens
470
+ )
471
+ k, conv_state_k = self.k_conv1d(
472
+ x=self.k_proj(hidden_states),
473
+ cache=conv_state_k,
474
+ output_final_state=use_cache,
475
+ cu_seqlens=cu_seqlens
476
+ )
477
+ v, conv_state_v = self.v_conv1d(
478
+ x=self.v_proj(hidden_states),
479
+ cache=conv_state_v,
480
+ output_final_state=use_cache,
481
+ cu_seqlens=cu_seqlens
482
+ )
483
+ g = self.f_b_proj(self.f_a_proj(hidden_states))
484
+ g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
485
+ beta = self.b_proj(hidden_states).float().sigmoid()
486
+
487
+ q, k = map(lambda x: rearrange(
488
+ x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
489
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
490
+
491
+ if mode == 'chunk':
492
+ o, recurrent_state = chunk_kda(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ beta=beta,
498
+ initial_state=recurrent_state,
499
+ output_final_state=True,
500
+ use_qk_l2norm_in_kernel=True,
501
+ cu_seqlens=cu_seqlens,
502
+ )
503
+ else:
504
+ o, recurrent_state = fused_recurrent_kda(
505
+ q=q,
506
+ k=k,
507
+ v=v,
508
+ g=g,
509
+ beta=beta,
510
+ initial_state=recurrent_state,
511
+ output_final_state=True,
512
+ use_qk_l2norm_in_kernel=True,
513
+ cu_seqlens=cu_seqlens,
514
+ )
515
+ if cache_params is not None:
516
+ cache_params.recurrent_states[self.layer_idx] = recurrent_state
517
+ cache_params.conv_states[self.layer_idx] = (
518
+ conv_state_q, conv_state_k, conv_state_v)
519
+
520
+ g = self.g_b_proj(self.g_a_proj(hidden_states))
521
+ g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
522
+ o = self.o_norm(o, g)
523
+
524
+ o = rearrange(o, 'b t h d -> b t (h d)')
525
+ o = self.o_proj(o)
526
+ if attention_mask is not None:
527
+ o = pad_input(o.squeeze(0), indices, batch_size, q_len)
528
+
529
+ return o
530
+
531
+
532
+ class KimiMoEGate(nn.Module):
533
+ """
534
+ MoEGate adapted from Deepseek-V3.
535
+ Parameter correspondences:
536
+ num_experts -> n_routed_experts
537
+ num_experts_per_token -> num_experts_per_tok
538
+ num_expert_group -> n_group
539
+ moe_router_activation_func -> scoring_func
540
+ """
541
+
542
+ def __init__(self, config: KimiLinearConfig):
543
+ super().__init__()
544
+ self.config = config
545
+ self.top_k = config.num_experts_per_token
546
+ self.num_experts = config.num_experts
547
+ self.routed_scaling_factor = config.routed_scaling_factor
548
+ self.moe_router_activation_func = config.moe_router_activation_func
549
+ self.num_expert_group = getattr(config, "num_expert_group", 1)
550
+ self.topk_group = getattr(config, "topk_group", 1)
551
+
552
+ # topk selection algorithm
553
+ self.moe_renormalize = config.moe_renormalize
554
+ self.gating_dim = config.hidden_size
555
+ self.weight = nn.Parameter(
556
+ torch.empty((self.num_experts, self.gating_dim))
557
+ )
558
+
559
+ self.e_score_correction_bias = nn.Parameter(
560
+ torch.empty((self.num_experts))
561
+ )
562
+ self.reset_parameters()
563
+
564
+ def reset_parameters(self) -> None:
565
+ import torch.nn.init as init
566
+
567
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
568
+
569
+ def forward(self, hidden_states):
570
+ bsz, seq_len, h = hidden_states.shape
571
+ # compute gating score
572
+ hidden_states = hidden_states.view(-1, h)
573
+ logits = F.linear(
574
+ hidden_states.type(torch.float32), self.weight.type(
575
+ torch.float32), None
576
+ )
577
+ if self.moe_router_activation_func == "sigmoid":
578
+ scores = logits.sigmoid()
579
+ elif self.moe_router_activation_func == "softmax":
580
+ scores = logits.softmax(dim=1)
581
+ else:
582
+ raise NotImplementedError(
583
+ f"insupportable scoring function for MoE gating: {self.moe_router_activation_func}"
584
+ )
585
+
586
+ # select top-k experts
587
+ assert not self.training
588
+ scores_for_choice = scores.view(bsz * seq_len, -1)
589
+ scores_for_choice += self.e_score_correction_bias.unsqueeze(0)
590
+ group_scores = (
591
+ scores_for_choice.view(
592
+ bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
593
+ ) # [n, num_expert_group]
594
+ group_idx = torch.topk(
595
+ group_scores, k=self.topk_group, dim=-1, sorted=False
596
+ )[
597
+ 1
598
+ ] # [n, top_k_group]
599
+ group_mask = torch.zeros_like(group_scores) # [n, num_expert_group]
600
+ group_mask.scatter_(1, group_idx, 1) # [n, num_expert_group]
601
+ score_mask = (
602
+ group_mask.unsqueeze(-1)
603
+ .expand(
604
+ bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group
605
+ )
606
+ .reshape(bsz * seq_len, -1)
607
+ ) # [n, e]
608
+ tmp_scores = scores_for_choice.masked_fill(
609
+ ~score_mask.bool(), 0.0) # [n, e]
610
+ _, topk_idx = torch.topk(
611
+ tmp_scores, k=self.top_k, dim=-1, sorted=False
612
+ )
613
+ topk_weight = scores.gather(1, topk_idx)
614
+
615
+ # norm gate to sum 1
616
+ if self.top_k > 1 and self.moe_renormalize:
617
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
618
+ topk_weight = topk_weight / denominator
619
+ # must multiply the scaling factor
620
+ topk_weight = topk_weight * self.routed_scaling_factor
621
+
622
+ return topk_idx, topk_weight
623
+
624
+
625
+ class KimiSparseMoeBlock(nn.Module):
626
+ """
627
+ Adapted from Deepseek-V3's MOE implementation
628
+ The namings are consistent with Kimi's version.
629
+ """
630
+
631
+ def __init__(self, config: KimiLinearConfig):
632
+ super().__init__()
633
+ self.config = config
634
+ self.hidden_dim = config.hidden_size
635
+ self.num_experts = config.num_experts
636
+ self.top_k = config.num_experts_per_token
637
+ self.moe_renormalize = config.moe_renormalize
638
+
639
+ self.ep_size = 1
640
+ self.experts_per_rank = config.num_experts
641
+ self.ep_rank = 0
642
+ self.experts = nn.ModuleList(
643
+ [
644
+ KimiBlockSparseMLP(
645
+ config, intermediate_size=config.moe_intermediate_size
646
+ )
647
+ for _ in range(config.num_experts)
648
+ ]
649
+ )
650
+ self.gate = KimiMoEGate(config)
651
+ if config.num_shared_experts is not None:
652
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
653
+ self.shared_experts = KimiMLP(
654
+ config=config, intermediate_size=intermediate_size
655
+ )
656
+
657
+ def forward(self, hidden_states):
658
+ identity = hidden_states
659
+ orig_shape = hidden_states.shape
660
+ topk_idx, topk_weight = self.gate(hidden_states)
661
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
662
+ flat_topk_idx = topk_idx.view(-1)
663
+ if not self.training:
664
+ y = self.moe_infer(hidden_states, topk_idx,
665
+ topk_weight).view(*orig_shape)
666
+ else:
667
+ raise NotImplementedError(
668
+ "Training mode is not supported in KimiSparseMoeBlock")
669
+ if self.config.num_shared_experts is not None:
670
+ y = y + self.shared_experts(identity)
671
+ return y
672
+
673
+ @torch.no_grad()
674
+ def moe_infer(self, x, topk_ids, topk_weight):
675
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
676
+ cnts.scatter_(1, topk_ids, 1)
677
+ tokens_per_expert = cnts.sum(dim=0)
678
+ idxs = topk_ids.view(-1).argsort()
679
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
680
+
681
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
682
+
683
+ outputs = []
684
+ start_idx = 0
685
+ for i, num_tokens in enumerate(tokens_per_expert):
686
+ end_idx = start_idx + num_tokens
687
+ if num_tokens == 0:
688
+ continue
689
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
690
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
691
+ expert_out = expert(tokens_for_this_expert)
692
+ outputs.append(expert_out)
693
+ start_idx = end_idx
694
+
695
+ outs = torch.cat(outputs, dim=0) if len(
696
+ outputs) else sorted_tokens.new_empty(0)
697
+
698
+ new_x = torch.empty_like(outs)
699
+ new_x[idxs] = outs
700
+ final_out = (
701
+ new_x.view(*topk_ids.shape, -1)
702
+ .type(topk_weight.dtype)
703
+ .mul_(topk_weight.unsqueeze(dim=-1))
704
+ .sum(dim=1)
705
+ .type(new_x.dtype)
706
+ )
707
+ return final_out
708
+
709
+
710
+ class KimiDecoderLayer(nn.Module):
711
+ def __init__(self, config: KimiLinearConfig, layer_idx: int):
712
+ super().__init__()
713
+ self.hidden_size = config.hidden_size
714
+ self.config = config
715
+ if config.is_kda_layer(layer_idx):
716
+ self.is_linear_attn = True
717
+ self.self_attn = KimiDeltaAttention(
718
+ config=config, layer_idx=layer_idx)
719
+ elif config.is_mla:
720
+ self.is_linear_attn = False
721
+ self.self_attn = KimiMLAAttention(
722
+ config=config, layer_idx=layer_idx)
723
+ else:
724
+ raise NotImplementedError
725
+ if (
726
+ config.num_experts is not None
727
+ and layer_idx >= config.first_k_dense_replace
728
+ and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
729
+ ):
730
+ self.block_sparse_moe = KimiSparseMoeBlock(config)
731
+ else:
732
+ self.mlp = KimiMLP(config)
733
+ self.input_layernorm = KimiRMSNorm(
734
+ config.hidden_size, eps=config.rms_norm_eps)
735
+ self.post_attention_layernorm = KimiRMSNorm(
736
+ config.hidden_size, eps=config.rms_norm_eps)
737
+
738
+ def forward(
739
+ self,
740
+ hidden_states: torch.Tensor,
741
+ attention_mask: Optional[torch.Tensor] = None,
742
+ position_ids: Optional[torch.LongTensor] = None,
743
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
744
+ output_attentions: Optional[bool] = False,
745
+ use_cache: Optional[bool] = False,
746
+ **kwargs: Unpack[FlashAttentionKwargs],
747
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ output_attentions (`bool`, *optional*):
754
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
755
+ returned tensors for more detail.
756
+ use_cache (`bool`, *optional*):
757
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
758
+ (see `past_key_values`).
759
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
760
+ """
761
+
762
+ residual = hidden_states
763
+
764
+ hidden_states = self.input_layernorm(hidden_states)
765
+
766
+ # Self Attention
767
+ if self.is_linear_attn is False:
768
+ hidden_states = self.self_attn(
769
+ hidden_states=hidden_states,
770
+ attention_mask=attention_mask,
771
+ position_ids=position_ids,
772
+ past_key_values=past_key_values,
773
+ output_attentions=output_attentions,
774
+ use_cache=use_cache,
775
+ **kwargs,
776
+ )
777
+ else:
778
+ hidden_states = self.self_attn(
779
+ hidden_states=hidden_states,
780
+ attention_mask=attention_mask,
781
+ cache_params=past_key_values,
782
+ output_attentions=output_attentions,
783
+ use_cache=use_cache,
784
+ **kwargs,
785
+ )
786
+ hidden_states = residual + hidden_states
787
+
788
+ # Fully Connected
789
+ residual = hidden_states
790
+ hidden_states = self.post_attention_layernorm(hidden_states)
791
+ if hasattr(self, "block_sparse_moe"):
792
+ hidden_states = self.block_sparse_moe(hidden_states)
793
+ else:
794
+ hidden_states = self.mlp(hidden_states)
795
+ hidden_states = residual + hidden_states
796
+
797
+ return hidden_states
798
+
799
+
800
+ class KimiPreTrainedModel(PreTrainedModel):
801
+ config_class = KimiLinearConfig
802
+ base_model_prefix = "model"
803
+ supports_gradient_checkpointing = True
804
+ _no_split_modules = ["KimiDecoderLayer"]
805
+ _skip_keys_device_placement = "past_key_values"
806
+ _supports_flash_attn_2 = True
807
+ _can_record_outputs = {
808
+ "router_logits": OutputRecorder(KimiBlockSparseMLP, index=1),
809
+ "hidden_states": KimiDecoderLayer,
810
+ "attentions": KimiMLAAttention,
811
+ }
812
+ _is_stateful = True
813
+
814
+ def _init_weights(self, module):
815
+ std = self.config.initializer_range
816
+ if isinstance(module, nn.Linear):
817
+ module.weight.data.normal_(mean=0.0, std=std)
818
+ if module.bias is not None:
819
+ module.bias.data.zero_()
820
+ elif isinstance(module, nn.Embedding):
821
+ module.weight.data.normal_(mean=0.0, std=std)
822
+ if module.padding_idx is not None:
823
+ module.weight.data[module.padding_idx].zero_()
824
+
825
+
826
+ class KimiLinearModel(KimiPreTrainedModel):
827
+ def __init__(self, config: KimiLinearConfig):
828
+ super().__init__(config)
829
+ self.padding_idx = config.pad_token_id
830
+ self.vocab_size = config.vocab_size
831
+
832
+ self.embed_tokens = nn.Embedding(
833
+ config.vocab_size, config.hidden_size, self.padding_idx)
834
+ self.layers = nn.ModuleList([KimiDecoderLayer(
835
+ config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
836
+ self.norm = KimiRMSNorm(
837
+ config.hidden_size, eps=config.rms_norm_eps)
838
+
839
+ if getattr(config, "_attn_implementation", None) is not None:
840
+ if config._attn_implementation != "flash_attention_2":
841
+ logger.warning_once(
842
+ f"Ignoring the provided attention implementation {config._attn_implementation}")
843
+ logger.warning_once("Using flash_attention_2 backend instead.")
844
+ config._attn_implementation = "flash_attention_2"
845
+ else:
846
+ config._attn_implementation = "flash_attention_2"
847
+
848
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
849
+ self.gradient_checkpointing = False
850
+ # Initialize weights and apply final processing
851
+ self.post_init()
852
+
853
+ def _update_linear_attn_mask(self, attention_mask, cache_position):
854
+ """
855
+ NOTE: Left-padding is used for linear attention mask.
856
+ No need for zeroing states when
857
+ 1. Cached forward
858
+ 2. Attending to all inputs
859
+ """
860
+ linear_attn_mask = attention_mask
861
+ if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
862
+ linear_attn_mask = None
863
+ return linear_attn_mask
864
+
865
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) == (inputs_embeds is None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds")
+
+        # Get inputs_embeds
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = KimiDynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length(
+            ) if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+        linear_attn_mask = self._update_linear_attn_mask(
+            attention_mask, cache_position)
+
+        hidden_states = inputs_embeds
+        if past_key_values is not None:
+            assert isinstance(past_key_values, KimiDynamicCache)
+
+        for decoder_layer in self.layers:
+            # KDA (linear attention) layers take the linear mask; full-attention layers take the causal mask.
+            layer_mask = linear_attn_mask if decoder_layer.is_linear_attn else causal_mask
+
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
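A hedged sketch of how the forward pass above manages its recurrent/KV cache; `model` is assumed to be an already-constructed `KimiLinearModel`, since loading the full 48B checkpoint is out of scope for an inline example:

```python
import torch

input_ids = torch.tensor([[1, 2, 3, 4]])

# Prefill: no cache is passed, so forward() creates a KimiDynamicCache because use_cache=True.
out = model(input_ids=input_ids, use_cache=True)

# One decoding step: feed only the new token plus the cache returned by the prefill.
step = model(
    input_ids=torch.tensor([[5]]),
    past_key_values=out.past_key_values,
    cache_position=torch.tensor([input_ids.shape[1]]),
    use_cache=True,
)
print(step.last_hidden_state.shape)  # (1, 1, hidden_size)
```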
+ class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = KimiLinearModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        generation_mode: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, KimiLinearForCausalLM
+
+        >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        if generation_mode:
+            # During generation only the last position is needed to produce the next-token logits.
+            hidden_states = hidden_states[:, -1:]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits, labels, self.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
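A hedged end-to-end usage sketch for the class above, going through the Hub checkpoint released with this upload (repo id as on the model card; generation settings are illustrative only):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)

messages = [{"role": "user", "content": "Summarize linear attention in one sentence."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
).to(model.device)

generated = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True))
```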
special_tokens_map.json ADDED
@@ -0,0 +1,260 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "[extra_id_0]",
4
+ "[extra_id_1]",
5
+ "[extra_id_2]",
6
+ "[extra_id_3]",
7
+ "[start_header_id]",
8
+ "[end_header_id]",
9
+ "[extra_id_4]",
10
+ "[EOT]",
11
+ "[extra_id_5]",
12
+ "[extra_id_6]",
13
+ "[extra_id_7]",
14
+ "[extra_id_8]",
15
+ "[extra_id_9]",
16
+ "[extra_id_10]",
17
+ "[extra_id_11]",
18
+ "[extra_id_12]",
19
+ "[extra_id_13]",
20
+ "[extra_id_14]",
21
+ "[extra_id_15]",
22
+ "[extra_id_16]",
23
+ "[extra_id_17]",
24
+ "[extra_id_18]",
25
+ "[extra_id_19]",
26
+ "[extra_id_20]",
27
+ "[extra_id_21]",
28
+ "[extra_id_22]",
29
+ "[extra_id_23]",
30
+ "[extra_id_24]",
31
+ "[extra_id_25]",
32
+ "[extra_id_26]",
33
+ "[extra_id_27]",
34
+ "[extra_id_28]",
35
+ "[extra_id_29]",
36
+ "[extra_id_30]",
37
+ "[extra_id_31]",
38
+ "[extra_id_32]",
39
+ "[extra_id_33]",
40
+ "[extra_id_34]",
41
+ "[extra_id_35]",
42
+ "[extra_id_36]",
43
+ "[extra_id_37]",
44
+ "[extra_id_38]",
45
+ "[extra_id_39]",
46
+ "[extra_id_40]",
47
+ "[extra_id_41]",
48
+ "[extra_id_42]",
49
+ "[extra_id_43]",
50
+ "[extra_id_44]",
51
+ "[extra_id_45]",
52
+ "[extra_id_46]",
53
+ "[extra_id_47]",
54
+ "[extra_id_48]",
55
+ "[extra_id_49]",
56
+ "[extra_id_50]",
57
+ "[extra_id_51]",
58
+ "[extra_id_52]",
59
+ "[extra_id_53]",
60
+ "[extra_id_54]",
61
+ "[extra_id_55]",
62
+ "[extra_id_56]",
63
+ "[extra_id_57]",
64
+ "[extra_id_58]",
65
+ "[extra_id_59]",
66
+ "[extra_id_60]",
67
+ "[extra_id_61]",
68
+ "[extra_id_62]",
69
+ "[extra_id_63]",
70
+ "[extra_id_64]",
71
+ "[extra_id_65]",
72
+ "[extra_id_66]",
73
+ "[extra_id_67]",
74
+ "[extra_id_68]",
75
+ "[extra_id_69]",
76
+ "[extra_id_70]",
77
+ "[extra_id_71]",
78
+ "[extra_id_72]",
79
+ "[extra_id_73]",
80
+ "[extra_id_74]",
81
+ "[extra_id_75]",
82
+ "[extra_id_76]",
83
+ "[extra_id_77]",
84
+ "[extra_id_78]",
85
+ "[extra_id_79]",
86
+ "[extra_id_80]",
87
+ "[extra_id_81]",
88
+ "[extra_id_82]",
89
+ "[extra_id_83]",
90
+ "[extra_id_84]",
91
+ "[extra_id_85]",
92
+ "[extra_id_86]",
93
+ "[extra_id_87]",
94
+ "[extra_id_88]",
95
+ "[extra_id_89]",
96
+ "[extra_id_90]",
97
+ "[extra_id_91]",
98
+ "[extra_id_92]",
99
+ "[extra_id_93]",
100
+ "[extra_id_94]",
101
+ "[extra_id_95]",
102
+ "[extra_id_96]",
103
+ "[extra_id_97]",
104
+ "[extra_id_98]",
105
+ "[extra_id_99]",
106
+ "[extra_id_100]",
107
+ "[extra_id_101]",
108
+ "[extra_id_102]",
109
+ "[extra_id_103]",
110
+ "[extra_id_104]",
111
+ "[extra_id_105]",
112
+ "[extra_id_106]",
113
+ "[extra_id_107]",
114
+ "[extra_id_108]",
115
+ "[extra_id_109]",
116
+ "[extra_id_110]",
117
+ "[extra_id_111]",
118
+ "[extra_id_112]",
119
+ "[extra_id_113]",
120
+ "[extra_id_114]",
121
+ "[extra_id_115]",
122
+ "[extra_id_116]",
123
+ "[extra_id_117]",
124
+ "[extra_id_118]",
125
+ "[extra_id_119]",
126
+ "[extra_id_120]",
127
+ "[extra_id_121]",
128
+ "[extra_id_122]",
129
+ "[extra_id_123]",
130
+ "[extra_id_124]",
131
+ "[extra_id_125]",
132
+ "[extra_id_126]",
133
+ "[extra_id_127]",
134
+ "[extra_id_128]",
135
+ "[extra_id_129]",
136
+ "[extra_id_130]",
137
+ "[extra_id_131]",
138
+ "[extra_id_132]",
139
+ "[extra_id_133]",
140
+ "[extra_id_134]",
141
+ "[extra_id_135]",
142
+ "[extra_id_136]",
143
+ "[extra_id_137]",
144
+ "[extra_id_138]",
145
+ "[extra_id_139]",
146
+ "[extra_id_140]",
147
+ "[extra_id_141]",
148
+ "[extra_id_142]",
149
+ "[extra_id_143]",
150
+ "[extra_id_144]",
151
+ "[extra_id_145]",
152
+ "[extra_id_146]",
153
+ "[extra_id_147]",
154
+ "[extra_id_148]",
155
+ "[extra_id_149]",
156
+ "[extra_id_150]",
157
+ "[extra_id_151]",
158
+ "[extra_id_152]",
159
+ "[extra_id_153]",
160
+ "[extra_id_154]",
161
+ "[extra_id_155]",
162
+ "[extra_id_156]",
163
+ "[extra_id_157]",
164
+ "[extra_id_158]",
165
+ "[extra_id_159]",
166
+ "[extra_id_160]",
167
+ "[extra_id_161]",
168
+ "[extra_id_162]",
169
+ "[extra_id_163]",
170
+ "[extra_id_164]",
171
+ "[extra_id_165]",
172
+ "[extra_id_166]",
173
+ "[extra_id_167]",
174
+ "[extra_id_168]",
175
+ "[extra_id_169]",
176
+ "[extra_id_170]",
177
+ "[extra_id_171]",
178
+ "[extra_id_172]",
179
+ "[extra_id_173]",
180
+ "[extra_id_174]",
181
+ "[extra_id_175]",
182
+ "[extra_id_176]",
183
+ "[extra_id_177]",
184
+ "[extra_id_178]",
185
+ "[extra_id_179]",
186
+ "[extra_id_180]",
187
+ "[extra_id_181]",
188
+ "[extra_id_182]",
189
+ "[extra_id_183]",
190
+ "[extra_id_184]",
191
+ "[extra_id_185]",
192
+ "[extra_id_186]",
193
+ "[extra_id_187]",
194
+ "[extra_id_188]",
195
+ "[extra_id_189]",
196
+ "[extra_id_190]",
197
+ "[extra_id_191]",
198
+ "[extra_id_192]",
199
+ "[extra_id_193]",
200
+ "[extra_id_194]",
201
+ "[extra_id_195]",
202
+ "[extra_id_196]",
203
+ "[extra_id_197]",
204
+ "[extra_id_198]",
205
+ "[extra_id_199]",
206
+ "[extra_id_200]",
207
+ "[extra_id_201]",
208
+ "[extra_id_202]",
209
+ "[extra_id_203]",
210
+ "[extra_id_204]",
211
+ "[extra_id_205]",
212
+ "[extra_id_206]",
213
+ "[extra_id_207]",
214
+ "[extra_id_208]",
215
+ "[extra_id_209]",
216
+ "[extra_id_210]",
217
+ "[extra_id_211]",
218
+ "[extra_id_212]",
219
+ "[extra_id_213]",
220
+ "[extra_id_214]",
221
+ "[extra_id_215]",
222
+ "[extra_id_216]",
223
+ "[extra_id_217]",
224
+ "[extra_id_218]",
225
+ "[extra_id_219]",
226
+ "[extra_id_220]",
227
+ "[extra_id_221]",
228
+ "[extra_id_222]",
229
+ "[extra_id_223]",
230
+ "[extra_id_224]",
231
+ "[extra_id_225]",
232
+ "[extra_id_226]",
233
+ "[extra_id_227]",
234
+ "[extra_id_228]",
235
+ "[extra_id_229]",
236
+ "[extra_id_230]",
237
+ "[extra_id_231]",
238
+ "[extra_id_232]",
239
+ "[extra_id_233]",
240
+ "[extra_id_234]",
241
+ "[extra_id_235]",
242
+ "[extra_id_236]",
243
+ "[extra_id_237]",
244
+ "[extra_id_238]",
245
+ "[extra_id_239]",
246
+ "[extra_id_240]",
247
+ "[extra_id_241]",
248
+ "[extra_id_242]",
249
+ "[extra_id_243]",
250
+ "[extra_id_244]",
251
+ "[extra_id_245]",
252
+ "[extra_id_246]",
253
+ "[extra_id_247]",
254
+ "[extra_id_248]"
255
+ ],
256
+ "bos_token": "[BOS]",
257
+ "eos_token": "[EOS]",
258
+ "pad_token": "[extra_id_250]",
259
+ "unk_token": "[extra_id_249]"
260
+ }
tiktoken.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6c497a7469b33ced9c38afb1ad6e47f03f5e5dc05f15930799210ec050c5103
3
+ size 2795286
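This LFS pointer resolves to the raw BPE ranks that `tokenization_kimi.py` loads via `load_tiktoken_bpe`. A minimal sketch of inspecting the file directly, assuming it has been materialized locally with `git lfs pull`:

```python
from tiktoken.load import load_tiktoken_bpe

# The real file is ~2.8 MB (see the pointer's size field), not the 3-line pointer itself.
mergeable_ranks = load_tiktoken_bpe("tiktoken.model")
print(len(mergeable_ranks))  # number of base (non-special) BPE tokens
```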
tokenization_kimi.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+ import tiktoken
3
+
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import (
7
+ cast,
8
+ Tuple,
9
+ Dict,
10
+ Iterator,
11
+ List,
12
+ Union,
13
+ Optional,
14
+ )
15
+ from shutil import copyfile
16
+ from tiktoken.load import load_tiktoken_bpe
17
+ from tokenizers import AddedToken, pre_tokenizers, Regex
18
+ from transformers.tokenization_utils import PreTrainedTokenizer
19
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
20
+ from typing import Any
21
+
22
+
23
+ logger = getLogger(__name__)
24
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
25
+
26
+
27
+ class TikTokenTokenizer(PreTrainedTokenizer):
28
+ """
29
+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
30
+
31
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
32
+ this superclass for more information regarding those methods.
33
+
34
+ Args:
35
+ vocab_file (`str`):
36
+ The path to the Tiktoken model file.
37
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
38
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
39
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
40
+ The end of sequence token.
41
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `None`):
42
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
43
+ token instead. The second to last item in special_tokens.
44
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `None`):
45
+ The token used for padding, for example when batching sequences of different lengths.
46
+ additional_special_tokens (list of `str`, *optional*):
47
+ A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
48
+ skipped when decoding if `skip_special_tokens` is set to `True`.
49
+ """
50
+
51
+ vocab_files_names = VOCAB_FILES_NAMES
52
+
53
+ model_input_names = ["input_ids", "attention_mask"]
54
+
55
+ special_tokens: Dict[str, int]
56
+
57
+ num_reserved_special_tokens = 256
58
+
59
+ pat_str = "|".join(
60
+ [
61
+ r"""[\p{Han}]+""",
62
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
63
+ r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
64
+ r"""\p{N}{1,3}""",
65
+ r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
66
+ r"""\s*[\r\n]+""",
67
+ r"""\s+(?!\S)""",
68
+ r"""\s+""",
69
+ ]
70
+ )
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file,
75
+ bos_token: Union[str, AddedToken]="[BOS]",
76
+ eos_token: Union[str, AddedToken]="[EOS]",
77
+ unk_token: Union[str, AddedToken, None]=None,
78
+ pad_token: Union[str, AddedToken, None]=None,
79
+ additional_special_tokens: List[str]=None,
80
+ added_tokens_decoder: Optional[dict] = None,
81
+ **kwargs,
82
+ ):
83
+ assert os.path.isfile(vocab_file), vocab_file
84
+
85
+ if additional_special_tokens is None:
86
+ additional_special_tokens = [
87
+ "<|im_end|>",
88
+ "<|im_user|>",
89
+ "<|im_assistant|>",
90
+ "<|start_header_id|>",
91
+ "<|end_header_id|>",
92
+ "[EOT]",
93
+ "<|im_system|>",
94
+ "<|im_middle|>",
95
+ ]
96
+
97
+ special_tokens_mapping = {
98
+ i: added_tokens_decoder[i].content for i in added_tokens_decoder
99
+ }
100
+
101
+ self.vocab_file = vocab_file
102
+ mergeable_ranks = load_tiktoken_bpe(vocab_file)
103
+ num_base_tokens = len(mergeable_ranks)
104
+ self.special_tokens = {
105
+ special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
106
+ for i in range(
107
+ num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
108
+ )
109
+ }
110
+
111
+
112
+
113
+ self.model = tiktoken.Encoding(
114
+ name=Path(vocab_file).name,
115
+ pat_str=self.pat_str,
116
+ mergeable_ranks=mergeable_ranks,
117
+ special_tokens=self.special_tokens,
118
+ )
119
+ logger.info(f"Reloaded tiktoken model from {vocab_file}")
120
+
121
+ self.n_words: int = self.model.n_vocab
122
+ # BOS / EOS token IDs
123
+ self.bos_id: int = self.special_tokens[str(bos_token)]
124
+ self.eos_id: int = self.special_tokens[str(eos_token)]
125
+ logger.info(
126
+ f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
127
+ )
128
+
129
+ self.pad_id: int = self.special_tokens[str(pad_token)]
130
+ self.unk_id: int = self.special_tokens[str(unk_token)]
131
+
132
+ self.byte_encoder = bytes_to_unicode()
133
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
134
+
135
+ self.decoder = {}
136
+ for i in range(self.n_words):
137
+ # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
138
+ decoding = ''.join([
139
+ self.byte_encoder[ord(char)] for char in
140
+ self.model.decode_single_token_bytes(i).decode('latin-1')
141
+ ])
142
+ self.decoder[i] = decoding
143
+
144
+ self.encoder = {}
145
+ for i in range(self.n_words):
146
+ if i in self.decoder:
147
+ self.encoder[self.decoder[i]] = i
148
+
149
+ super().__init__(
150
+ bos_token=bos_token,
151
+ eos_token=eos_token,
152
+ unk_token=unk_token,
153
+ pad_token=pad_token,
154
+ additional_special_tokens=additional_special_tokens,
155
+ **kwargs,
156
+ )
157
+ self.all_special_ids_set = set(self.all_special_ids)
158
+
159
+ def encode(
160
+ self,
161
+ text: str,
162
+ allow_special_tokens: bool = True,
163
+ **kwargs
164
+ ) -> List[int]:
165
+ """
166
+ Encodes a string into a list of token IDs.
167
+
168
+ Args:
169
+ text (str): The input string to be encoded.
170
+
171
+ Returns:
172
+ list[int]: A list of token IDs.
173
+ """
174
+ # If extra kwargs are passed, fall back to super().encode, which has a lot of code
175
+ # to handle them; super().encode ultimately calls _tokenize and _convert_token_to_id.
176
+ # NOTE: our encode method is not compatible with the super().encode method,
177
+ # e.g. split_special_tokens' default is True in our encode method.
178
+ if len(kwargs) > 0:
179
+ logger.warning( f"Calling super().encode with {kwargs}" )
180
+ return super().encode(text, **kwargs)
181
+
182
+ assert type(text) is str
183
+
184
+ # The tiktoken tokenizer can handle <=400k chars without
185
+ # pyo3_runtime.PanicException.
186
+ TIKTOKEN_MAX_ENCODE_CHARS = 400_000
187
+
188
+ # https://github.com/openai/tiktoken/issues/195
189
+ # Here we iterate over subsequences and split if we exceed the limit
190
+ # of max consecutive non-whitespace or whitespace characters.
191
+ MAX_NO_WHITESPACES_CHARS = 25_000
192
+
193
+ texts = self.pre_tokenizer_process(text)
194
+
195
+ all_substrs = []
196
+ for text in texts:
197
+ substrs = (
198
+ substr
199
+ for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
200
+ for substr in self._split_whitespaces_or_nonwhitespaces(
201
+ text[i: i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
202
+ )
203
+ )
204
+ all_substrs.extend(substrs)
205
+
206
+ t: List[int] = []
207
+ for substr in all_substrs:
208
+ if allow_special_tokens:
209
+ t.extend(
210
+ # we should consider special token as a common token
211
+ self.model.encode(
212
+ substr,
213
+ allowed_special="all",
214
+ )
215
+ )
216
+ else:
217
+ t.extend(
218
+ # we should consider special token as a common token
219
+ self.model.encode(
220
+ substr,
221
+ disallowed_special=(),
222
+ )
223
+ )
224
+
225
+ return t
226
+
227
+ def decode(
228
+ self,
229
+ token_ids: Union[int, List[int]],
230
+ **kwargs
231
+ ) -> str:
232
+ """
233
+ Decodes a list of token IDs into a string.
234
+
235
+ Args:
236
+ token_ids (List[int]): The list of token IDs to be decoded.
237
+
238
+ Returns:
239
+ str: The decoded string.
240
+ """
241
+ # If extra kwargs are passed, fall back to super().decode, which has a lot of code
242
+ # to handle them; super().decode ultimately calls convert_tokens_to_string and _convert_id_to_token.
243
+ if len(kwargs) > 0:
244
+ return super().decode(token_ids, **kwargs)
245
+
246
+ if type(token_ids) is int:
247
+ token_ids = [token_ids]
248
+
249
+ return self.model.decode(cast(List[int], token_ids))
250
+
251
+ @staticmethod
252
+ def _split_whitespaces_or_nonwhitespaces(
253
+ s: str, max_consecutive_slice_len: int
254
+ ) -> Iterator[str]:
255
+ """
256
+ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
257
+ consecutive whitespaces or consecutive non-whitespaces.
258
+ """
259
+ current_slice_len = 0
260
+ current_slice_is_space = s[0].isspace() if len(s) > 0 else False
261
+ slice_start = 0
262
+
263
+ for i in range(len(s)):
264
+ is_now_space = s[i].isspace()
265
+
266
+ if current_slice_is_space ^ is_now_space:
267
+ current_slice_len = 1
268
+ current_slice_is_space = is_now_space
269
+ else:
270
+ current_slice_len += 1
271
+ if current_slice_len > max_consecutive_slice_len:
272
+ yield s[slice_start:i]
273
+ slice_start = i
274
+ current_slice_len = 1
275
+ yield s[slice_start:]
276
+
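A small self-check, not part of the file, of the helper above; it shows that long same-class runs are chunked so tiktoken never sees more than `max_consecutive_slice_len` consecutive whitespace or non-whitespace characters:

```python
# Assumes tokenization_kimi.py is importable from the working directory.
from tokenization_kimi import TikTokenTokenizer

chunks = list(
    TikTokenTokenizer._split_whitespaces_or_nonwhitespaces(
        "aaaaa   bb", max_consecutive_slice_len=2
    )
)
print(chunks)  # ['aa', 'aa', 'a  ', ' bb'] -- concatenating the chunks restores the input
```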
277
+ def pre_tokenizer_process(self, text: str) -> List[str]:
278
+ """
279
+ Pre-tokenizes the input text into a list of text chunks.
280
+ This method is used to split the input text into smaller chunks for internal processing.
281
+ """
282
+ return [text]
283
+
284
+
285
+ """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
286
+ @property
287
+ def vocab_size(self) -> int:
288
+ return self.n_words
289
+
290
+ def get_vocab(self) -> Dict[str, int]:
291
+ return self.encoder
292
+
293
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
294
+ return [
295
+ self.decoder[t]
296
+ for t in self.encode(text)
297
+ ]
298
+
299
+ def _convert_token_to_id(self, token: str) -> int:
300
+ return self.encoder.get(token, self.unk_id)
301
+
302
+ def _convert_id_to_token(self, index: int) -> str:
303
+ return self.decoder.get(index)
304
+
305
+ @staticmethod
306
+ def clean_up_tokenization(out_string: str) -> str:
307
+ return out_string
308
+
309
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
310
+ text = ''.join(tokens)
311
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', 'replace')
312
+ return text
313
+
314
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
+ if not os.path.isdir(save_directory):
316
+ raise ValueError(f"vocabulary path ({save_directory}) should be a directory")
317
+ out_vocab_file = os.path.join(
318
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
319
+ )
320
+
321
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
322
+ copyfile(self.vocab_file, out_vocab_file)
323
+
324
+ return (out_vocab_file,)
325
+
326
+
327
+
328
+ def apply_chat_template(
329
+ self, conversation, tools: Optional[list[dict]] = None,
330
+ tokenize: bool = False,
331
+ add_generation_prompt: bool = True,
332
+ **kwargs
333
+ ):
334
+ tools = deep_sort_dict(tools)
335
+ return super().apply_chat_template(conversation,
336
+ tools=tools,
337
+ tokenize=tokenize,
338
+ add_generation_prompt=add_generation_prompt,
339
+ **kwargs)
340
+
341
+
342
+ def deep_sort_dict(obj: Any) -> Any:
343
+ if isinstance(obj, dict):
344
+ return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
345
+ if isinstance(obj, list):
346
+ return [deep_sort_dict(item) for item in obj]
347
+ return obj
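A hedged sketch of using this tokenizer through the `auto_map` entry declared in `tokenizer_config.json` below; the repo id is taken from the model card, and the chat template ships with the checkpoint:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
)

# Plain-text round trip through the tiktoken fast path.
ids = tok.encode("Kimi Linear mixes KDA layers with full attention.")
print(tok.decode(ids))

# Chat formatting; apply_chat_template sorts any tool schemas via deep_sort_dict before templating.
messages = [{"role": "user", "content": "hello"}]
print(tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False))
```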
tokenizer_config.json ADDED
@@ -0,0 +1,164 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "163584": {
4
+ "content": "[BOS]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "163585": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "163586": {
20
+ "content": "<|im_end|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "163587": {
28
+ "content": "<|im_user|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "163588": {
36
+ "content": "<|im_assistant|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "163590": {
44
+ "content": "<|start_header_id|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "163591": {
52
+ "content": "<|end_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "163593": {
60
+ "content": "[EOT]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "163594": {
68
+ "content": "<|im_system|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "163595": {
76
+ "content": "<|tool_calls_section_begin|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "163596": {
84
+ "content": "<|tool_calls_section_end|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "163597": {
92
+ "content": "<|tool_call_begin|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "163598": {
100
+ "content": "<|tool_call_argument_begin|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "163599": {
108
+ "content": "<|tool_call_end|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "163601": {
116
+ "content": "<|im_middle|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "163838": {
124
+ "content": "[UNK]",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "163839": {
132
+ "content": "[PAD]",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ }
139
+ },
140
+ "additional_special_tokens": [
141
+ "<|im_end|>",
142
+ "<|im_user|>",
143
+ "<|im_assistant|>",
144
+ "<|start_header_id|>",
145
+ "<|end_header_id|>",
146
+ "[EOT]",
147
+ "<|im_system|>",
148
+ "<|im_middle|>"
149
+ ],
150
+ "bos_token": "[BOS]",
151
+ "clean_up_tokenization_spaces": false,
152
+ "eos_token": "[EOS]",
153
+ "extra_special_tokens": {},
154
+ "model_max_length": 1000000000000000019884624838656,
155
+ "pad_token": "[PAD]",
156
+ "tokenizer_class": "TikTokenTokenizer",
157
+ "unk_token": "[UNK]",
158
+ "auto_map": {
159
+ "AutoTokenizer": [
160
+ "tokenization_kimi.TikTokenTokenizer",
161
+ null
162
+ ]
163
+ }
164
+ }
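Since `added_tokens_decoder` pins the control tokens to fixed ids, a hedged sanity check after loading the tokenizer could look like this (expected values are the ones listed above):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
)

print(tok.convert_tokens_to_ids("[BOS]"))       # 163584
print(tok.convert_tokens_to_ids("[EOS]"))       # 163585
print(tok.convert_tokens_to_ids("<|im_end|>"))  # 163586
print(tok.pad_token, tok.unk_token)             # [PAD] [UNK]
```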