metascroy committed
Commit b9b456c · verified · 1 Parent(s): aaf5763

Update README.md

Files changed (1): README.md (+183 -0)
This mistral3 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
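
The script below shows the full workflow: the base model is loaded for quantization-aware training (QAT) with the mobile-CPU-friendly `int8-int4` scheme (8-bit embeddings, 4-bit linear weights with 8-bit dynamically quantized activations, the ExecuTorch CPU layout), fine-tuned on the `unsloth/LaTeX_OCR` dataset with TRL's `SFTTrainer`, and finally converted to torchao format for saving locally or pushing to the Hub.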

```python
################################################################################
# We first load the model for QAT using the mobile CPU friendly int8-int4 scheme
################################################################################

from unsloth import FastVisionModel
from unsloth.chat_templates import (
    get_chat_template,
)
import torch

MODEL_ID = "unsloth/Ministral-3-3B-Instruct-2512"
QAT_SCHEME = "int8-int4"

model, tokenizer = FastVisionModel.from_pretrained(
    model_name = MODEL_ID,
    max_seq_length = 2048,
    dtype = torch.bfloat16,
    load_in_4bit = False,
    full_finetuning = True,
    # ExecuTorch CPU quantization scheme
    # Quantize embedding to 8-bits, and quantize linear layers to 4-bits
    # with 8-bit dynamically quantized activations
    qat_scheme = QAT_SCHEME,
)

print(model)

################################################################################
# Data prep
################################################################################

from datasets import load_dataset
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")

# Convert the dataset into a conversational format
instruction = "Write the LaTeX representation for this image."

def convert_to_conversation(sample):
    conversation = [
        { "role": "user",
          "content" : [
            {"type" : "text",  "text"  : instruction},
            {"type" : "image", "image" : sample["image"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["text"]} ]
        },
    ]
    return { "messages" : conversation }

converted_dataset = [convert_to_conversation(sample) for sample in dataset]

print(converted_dataset[0])

################################################################################
# Before finetuning
################################################################################
FastVisionModel.for_inference(model) # Enable for inference!

image = dataset[2]["image"]
instruction = "Write the LaTeX representation for this image."

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 64,
                   use_cache = True, temperature = 1.5, min_p = 0.1)

################################################################################
# Define trainer
################################################################################

from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from unsloth import is_bf16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer), # Must use!
    train_dataset = converted_dataset,
    args = SFTConfig(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 2,
        warmup_steps = 5,
        max_steps = 30,
        # num_train_epochs = 1, # Set this instead of max_steps for full training runs
        learning_rate = 3e-5,
        logging_steps = 1,
        optim = "adamw_8bit",
        fp16 = not is_bf16_supported(), # Use fp16 if bf16 is not supported
        bf16 = is_bf16_supported(),     # Use bf16 if supported
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    ),
)

################################################################################
# Do fine tuning
################################################################################
trainer_stats = trainer.train()

################################################################################
# Inference after finetuning
################################################################################
FastVisionModel.for_inference(model) # Enable for inference!

image = dataset[2]["image"]
instruction = "Write the LaTeX representation for this image."

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)

################################################################################
# Convert model to torchao format and save
################################################################################

from unsloth.models._utils import _convert_torchao_model
_convert_torchao_model(model)

model_name = MODEL_ID.split("/")[-1]
save_to = f"{model_name}-{QAT_SCHEME}-unsloth"

# Save locally (torchao tensor subclasses currently require non-safetensors
# serialization, so keep safe_serialization=False)
model.save_pretrained(save_to, safe_serialization=False)
tokenizer.save_pretrained(save_to)

# Or save to hub
from huggingface_hub import get_token, whoami

def _get_username():
    token = get_token()
    username = whoami(token=token)["name"]
    return username

username = _get_username()
model.push_to_hub(f"{username}/{save_to}", safe_serialization=False)
tokenizer.push_to_hub(f"{username}/{save_to}")
```
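
As a rough follow-up sketch (not part of the script above): once the torchao-format checkpoint has been saved or pushed, it should be reloadable through the standard `transformers` loading path, provided `torchao` is installed. The Auto class and directory name below are assumptions for illustration; the exact class for this checkpoint may differ.

```python
# Hypothetical reload sketch -- assumes transformers + torchao are installed and that
# the torchao-serialized checkpoint produced above loads via from_pretrained.
# The Auto class is illustrative; the directory name matches save_to from the script.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

save_dir = "Ministral-3-3B-Instruct-2512-int8-int4-unsloth"  # written by save_pretrained above

model = AutoModelForImageTextToText.from_pretrained(save_dir, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(save_dir)
```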