noamwies commited on
Commit
5a7cdd2
·
1 Parent(s): 051951e

Upload create_miniature_model.py

Browse files
Files changed (1) hide show
  1. create_miniature_model.py +46 -0
create_miniature_model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import tokenizers
4
+ import torch
5
+ import transformers
6
+
7
+
8
+ def shrink_vocab(tokenizer, new_vocab_size):
9
+ tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
10
+ vocab = tokenizer_json["model"]["vocab"]
11
+ if tokenizer_json["model"]["type"] == "BPE":
12
+ new_vocab = { token: i for token, i in vocab.items() if i < new_vocab_size }
13
+ merges = tokenizer_json["model"]["merges"]
14
+ new_merges = []
15
+ for i in range(len(merges)):
16
+ if len( merges[i].split()) == 2:
17
+ a, b = merges[i].split()
18
+ else:
19
+ print('skip')
20
+ continue
21
+ new_token = "".join((a, b))
22
+ if a in new_vocab and b in new_vocab and new_token in new_vocab:
23
+ new_merges.append(merges[i])
24
+ tokenizer_json["model"]["merges"] = new_merges
25
+ elif tokenizer_json["model"]["type"] == "Unigram":
26
+ new_vocab = vocab[:new_vocab_size]
27
+ elif tokenizer_json["model"]["type"] == "WordPiece" or tokenizer_json["model"]["type"] == "WordLevel":
28
+ new_vocab = { token: i for token, i in vocab.items() if i < new_vocab_size }
29
+ else:
30
+ raise ValueError(f"don't know how to handle {tokenizer_json['model']['type']}")
31
+ tokenizer_json["model"]["vocab"] = new_vocab
32
+ tokenizer._tokenizer = tokenizers.Tokenizer.from_str(json.dumps(tokenizer_json))
33
+
34
+
35
+ def main():
36
+ tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
37
+ shrink_vocab(tokenizer, new_vocab_size=2000)
38
+ tokenizer.save_pretrained(".")
39
+
40
+ config = transformers.AutoConfig.from_pretrained('noamwies/llama-test-gqa-with-better-transformer')
41
+ model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=config.torch_dtype)
42
+ torch.save(model.state_dict(), 'pytorch_model.bin')
43
+
44
+
45
+ if __name__ == '__main__':
46
+ main()