Darsala committed on
Commit 9412f10 · verified · 1 Parent(s): b05e960

Upload encoder_decoder_tokenizer.py

Files changed (1)
  1. encoder_decoder_tokenizer.py +470 -0
encoder_decoder_tokenizer.py ADDED
@@ -0,0 +1,470 @@
+ """
+ Encoder-Decoder Tokenizer Implementations
+
+ Provides tokenizer implementations for encoder-decoder models.
+ """
+ import os
+ import numpy as np
+ import torch
+ from pathlib import Path
+ from overrides import overrides
+ from typing import Dict, Any, Tuple, Union, List, Optional, overload
+ from datasets import Dataset, DatasetDict
+ from transformers.tokenization_utils_base import (
+     AddedToken,  # type: ignore
+     BatchEncoding,
+     EncodedInput,
+     EncodedInputPair,
+     PreTokenizedInput,
+     PreTokenizedInputPair,
+     TextInput,
+     TextInputPair,
+     TruncationStrategy,
+ )
+ from transformers.utils import logging
+ from transformers import AutoTokenizer
+ from transformers.utils.generic import PaddingStrategy, TensorType
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers import EncoderDecoderModel
+
+ logger = logging.get_logger(__name__)
+
+ SPIECE_UNDERLINE = "▁"
+
+
+ class EncoderDecoderTokenizer(PreTrainedTokenizer):
+     def __init__(self, encoder_tokenizer_path, decoder_tokenizer_path, **kwargs):
+         self.encoder: PreTrainedTokenizer = AutoTokenizer.from_pretrained(encoder_tokenizer_path)
+         self.decoder: PreTrainedTokenizer = AutoTokenizer.from_pretrained(decoder_tokenizer_path)
+         self.current_tokenizer = self.encoder
+         self._decode_use_source_tokenizer = False
+
+         if self.decoder.eos_token is None:
+             self.decoder.eos_token = self.decoder.sep_token
+
+         if self.encoder.eos_token is None:
+             self.encoder.eos_token = self.encoder.sep_token
+
+         if self.encoder.pad_token is None:
+             self.encoder.pad_token = self.encoder.eos_token
+         if self.decoder.pad_token is None:
+             self.decoder.pad_token = self.decoder.eos_token
+
+         if self.encoder.bos_token is None:
+             self.encoder.bos_token = self.encoder.cls_token
+         if self.decoder.bos_token is None:
+             self.decoder.bos_token = self.decoder.cls_token
+
+         self._pad_token = self.encoder.pad_token
+         self._unk_token = self.encoder.unk_token
+         self._bos_token = self.encoder.bos_token
+         self._eos_token = self.encoder.eos_token
+         self._sep_token = self.encoder.sep_token
+         self._cls_token = self.encoder.cls_token
+         self._mask_token = self.encoder.mask_token
+         self.decoder_pad_token = self.decoder.pad_token
+         self.decoder_unk_token = self.decoder.unk_token
+         self.decoder_bos_token = self.decoder.bos_token
+         self.decoder_eos_token = self.decoder.eos_token
+         self.decoder_sep_token = self.decoder.sep_token
+         self.decoder_cls_token = self.decoder.cls_token
+         self.decoder_mask_token = self.decoder.mask_token
+
+         self.decoder_pad_token_id = self.decoder.pad_token_id
+         self.decoder_unk_token_id = self.decoder.unk_token_id
+         self.decoder_bos_token_id = self.decoder.bos_token_id
+         self.decoder_eos_token_id = self.decoder.eos_token_id
+         self.decoder_sep_token_id = self.decoder.sep_token_id
+         self.decoder_cls_token_id = self.decoder.cls_token_id
+         self.decoder_mask_token_id = self.decoder.mask_token_id
+         self._additional_special_tokens = []
+
+     @property
+     def is_fast(self) -> bool:
+         return self.current_tokenizer.is_fast
+
+     @property
+     def vocab_size(self) -> int:
+         """
+         `int`: Size of the base vocabulary (without the added tokens).
+         """
+         return self.current_tokenizer.vocab_size
+
+     @property
+     def added_tokens_encoder(self) -> Dict[str, int]:
+         """
+         Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+         optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+         """
+         return self.current_tokenizer.added_tokens_encoder
+
+     @property
+     def added_tokens_decoder(self) -> Dict[int, AddedToken]:
+         """
+         Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
+
+         Returns:
+             `Dict[int, AddedToken]`: The added tokens.
+         """
+         return self.current_tokenizer.added_tokens_decoder
+
+     @added_tokens_decoder.setter
+     def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> None:
+         self.current_tokenizer.added_tokens_decoder = value
+
+     def get_added_vocab(self) -> Dict[str, int]:
+         """
+         Returns the added tokens in the vocabulary as a dictionary of token to index.
+
+         Returns:
+             `Dict[str, int]`: The added tokens.
+         """
+         return self.current_tokenizer.get_added_vocab()
+
+     def __len__(self):
+         """
+         Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise
+         if there is a hole in the vocab, we will add tokens at a wrong index.
+         """
+         return len(set(self.get_vocab().keys()))
+
+     def num_special_tokens_to_add(self, pair: bool = False) -> int:
+         """
+         Returns the number of added tokens when encoding a sequence with special tokens.
+
+         <Tip>
+
+         This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
+         this inside your training loop.
+
+         </Tip>
+
+         Args:
+             pair (`bool`, *optional*, defaults to `False`):
+                 Whether the number of added tokens should be computed in the case of a sequence pair or a single
+                 sequence.
+
+         Returns:
+             `int`: Number of special tokens added to sequences.
+         """
+         return self.current_tokenizer.num_special_tokens_to_add(pair)
+
+     def tokenize(self, text: TextInput, **kwargs):
+         """
+         Converts a string into a sequence of tokens, using the tokenizer.
+
+         Splits into words for word-based vocabularies or sub-words for sub-word-based vocabularies
+         (BPE/SentencePieces/WordPieces). Takes care of added tokens.
+
+         Args:
+             text (`str`):
+                 The sequence to be encoded.
+             **kwargs (additional keyword arguments):
+                 Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+         Returns:
+             `List[str]`: The list of tokens.
+         """
+         return self.decoder.tokenize(text, **kwargs)
+
+     def _tokenize(self, text, **kwargs):
+         """
+         Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
+         vocabularies or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
+
+         Does NOT take care of added tokens.
+         """
+         return self.decoder._tokenize(text, **kwargs)
+
+     def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
+         """
+         Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
+         vocabulary.
+
+         Args:
+             tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).
+
+         Returns:
+             `int` or `List[int]`: The token id or list of token ids.
+         """
+         return self.current_tokenizer.convert_tokens_to_ids(tokens)
+
+     def _convert_token_to_id_with_added_voc(self, token):
+         return self.current_tokenizer._convert_token_to_id_with_added_voc(token)
+
+     def _convert_token_to_id(self, token):
+         return self.current_tokenizer._convert_token_to_id(token)
+
+     def encode(self, *args, **kwargs):
+         return self.current_tokenizer.encode(*args, **kwargs)
+
+     def _batch_encode_plus(
+         self,
+         batch_text_or_text_pairs: Union[
+             List[TextInput],
+             List[TextInputPair],
+             List[PreTokenizedInput],
+             List[PreTokenizedInputPair],
+             List[EncodedInput],
+             List[EncodedInputPair],
+         ],
+         add_special_tokens: bool = True,
+         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+         truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+         max_length: Optional[int] = None,
+         stride: int = 0,
+         is_split_into_words: bool = False,
+         pad_to_multiple_of: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         return_token_type_ids: Optional[bool] = None,
+         return_attention_mask: Optional[bool] = None,
+         return_overflowing_tokens: bool = False,
+         return_special_tokens_mask: bool = False,
+         return_offsets_mapping: bool = False,
+         return_length: bool = False,
+         verbose: bool = True,
+         **kwargs,
+     ) -> BatchEncoding:
+         return self.current_tokenizer._batch_encode_plus(
+             batch_text_or_text_pairs=batch_text_or_text_pairs,
+             add_special_tokens=add_special_tokens,
+             padding_strategy=padding_strategy,
+             truncation_strategy=truncation_strategy,
+             max_length=max_length,
+             stride=stride,
+             is_split_into_words=is_split_into_words,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_tensors=return_tensors,
+             return_token_type_ids=return_token_type_ids,
+             return_attention_mask=return_attention_mask,
+             return_overflowing_tokens=return_overflowing_tokens,
+             return_special_tokens_mask=return_special_tokens_mask,
+             return_offsets_mapping=return_offsets_mapping,
+             return_length=return_length,
+             verbose=verbose,
+             **kwargs,
+         )
+
+     def prepare_for_tokenization(
+         self, text: str, is_split_into_words: bool = False, **kwargs
+     ) -> Tuple[str, Dict[str, Any]]:
+         """
+         Performs any necessary transformations before tokenization.
+
+         This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
+         `kwargs` at the end of the encoding process to be sure all the arguments have been used.
+
+         Args:
+             text (`str`):
+                 The text to prepare.
+             is_split_into_words (`bool`, *optional*, defaults to `False`):
+                 Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                 tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                 which it will tokenize. This is useful for NER or token classification.
+             kwargs (`Dict[str, Any]`, *optional*):
+                 Keyword arguments to use for the tokenization.
+
+         Returns:
+             `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
+         """
+         return self.current_tokenizer.prepare_for_tokenization(text, is_split_into_words, **kwargs)
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of ids of the first sequence.
+             token_ids_1 (`List[int]`, *optional*):
+                 List of ids of the second sequence.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         return self.current_tokenizer.get_special_tokens_mask(token_ids_0, token_ids_1, already_has_special_tokens)
+
+     @overload
+     def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str:
+         ...
+
+     @overload
+     def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]:
+         ...
+
+     def convert_ids_to_tokens(
+         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+     ) -> Union[str, List[str]]:
+         """
+         Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
+         and added tokens.
+
+         Args:
+             ids (`int` or `List[int]`):
+                 The token id (or token ids) to convert to tokens.
+             skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not to remove special tokens in the decoding.
+
+         Returns:
+             `str` or `List[str]`: The decoded token(s).
+         """
+         return self.current_tokenizer.convert_ids_to_tokens(ids, skip_special_tokens)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         return self.current_tokenizer.convert_tokens_to_string(tokens)
+
+     def decode(
+         self,
+         token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
+         skip_special_tokens: bool = False,
+         clean_up_tokenization_spaces: Optional[bool] = None,
+         **kwargs,
+     ) -> str:
+         return self.decoder.decode(token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
+
+     @overrides
+     def __call__(self, text, text_target=None, *args, **kwargs):
+         # Encode the source text with the encoder tokenizer, appending its EOS token.
+         if isinstance(text, str):
+             text = text + self.eos_token
+         else:
+             text = [i + self.eos_token for i in text]
+         results = self.encoder(text, *args, **kwargs)
+         if text_target:
+             # Encode targets with the decoder tokenizer and mask pad positions in the labels.
+             tmp = self.decoder(text_target, *args, **kwargs)
+             results['labels'] = tmp['input_ids']
+             results['labels'][results['labels'] == self.decoder.pad_token_id] = -100
+             results['decoder_attention_mask'] = tmp['attention_mask']
+         return results
+
+     def _decode(
+         self,
+         token_ids: List[int],
+         skip_special_tokens: bool = False,
+         clean_up_tokenization_spaces: Optional[bool] = None,
+         spaces_between_special_tokens: bool = True,
+         **kwargs,
+     ) -> str:
+         return self.decoder._decode(
+             token_ids,
+             skip_special_tokens,
+             clean_up_tokenization_spaces,
+             spaces_between_special_tokens,
+         )
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         legacy_format: Optional[bool] = None,
+         filename_prefix: Optional[str] = None,
+         push_to_hub: bool = False,
+         **kwargs,
+     ) -> None:
+         # Each underlying tokenizer is saved to its own subdirectory.
+         encoder_path = Path(save_directory) / Path("encoder")
+         decoder_path = Path(save_directory) / Path("decoder")
+         self.encoder.save_pretrained(encoder_path, legacy_format, filename_prefix, push_to_hub, **kwargs)
+         self.decoder.save_pretrained(decoder_path, legacy_format, filename_prefix, push_to_hub, **kwargs)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, os.PathLike],
+         *init_inputs,
+         cache_dir: Optional[Union[str, os.PathLike]] = None,
+         force_download: bool = False,
+         local_files_only: bool = False,
+         token: Optional[Union[str, bool]] = None,
+         revision: str = "main",
+         **kwargs,
+     ):
+         # Reload the two tokenizers from the "encoder" and "decoder" subdirectories
+         # written by `save_pretrained`.
+         encoder_path = Path(pretrained_model_name_or_path) / Path("encoder")
+         decoder_path = Path(pretrained_model_name_or_path) / Path("decoder")
+
+         return EncoderDecoderTokenizer(encoder_path, decoder_path)
+
+     def _switch_to_target_mode(self):
+         self.current_tokenizer = self.decoder
+
+     def _switch_to_input_mode(self):
+         self.current_tokenizer = self.encoder
+
+     @property
+     def pad_token_id(self) -> Any:
+         """Return pad token ID from current tokenizer."""
+         return self.current_tokenizer.pad_token_id
+
+     @property
+     def unk_token_id(self) -> Any:
+         """Return unk token ID from current tokenizer."""
+         return self.current_tokenizer.unk_token_id
+
+     @property
+     def bos_token_id(self) -> Any:
+         """Return bos token ID from current tokenizer."""
+         return self.current_tokenizer.bos_token_id
+
+     @property
+     def eos_token_id(self) -> Any:
+         """Return eos token ID from current tokenizer."""
+         return self.current_tokenizer.eos_token_id
+
+     @property
+     def sep_token_id(self) -> Any:
+         """Return sep token ID from current tokenizer."""
+         return self.current_tokenizer.sep_token_id
+
+     @property
+     def cls_token_id(self) -> Any:
+         """Return cls token ID from current tokenizer."""
+         return self.current_tokenizer.cls_token_id
+
+     @property
+     def mask_token_id(self) -> Any:
+         """Return mask token ID from current tokenizer."""
+         return self.current_tokenizer.mask_token_id
+
+     def get_vocab(self) -> Dict[str, int]:
+         """
+         Returns the vocabulary as a dictionary of token to indices.
+         """
+         return self.current_tokenizer.get_vocab()
+
+     @property
+     def pad_token(self) -> Any:
+         """Return pad token from current tokenizer."""
+         return self.current_tokenizer.pad_token
+
+     @property
+     def unk_token(self) -> Any:
+         """Return unk token from current tokenizer."""
+         return self.current_tokenizer.unk_token
+
+     @property
+     def bos_token(self) -> Any:
+         """Return bos token from current tokenizer."""
+         return self.current_tokenizer.bos_token
+
+     @property
+     def eos_token(self) -> Any:
+         """Return eos token from current tokenizer."""
+         return self.current_tokenizer.eos_token
+
+     @property
+     def sep_token(self) -> Any:
+         """Return sep token from current tokenizer."""
+         return self.current_tokenizer.sep_token
+
+     @property
+     def cls_token(self) -> Any:
+         """Return cls token from current tokenizer."""
+         return self.current_tokenizer.cls_token
+
+     @property
+     def mask_token(self) -> Any:
+         """Return mask token from current tokenizer."""
+         return self.current_tokenizer.mask_token