# Gemma-2B-GSM8K-CoT-LoRA (Fine-tuned on Google Colab Free Tier)
This repository contains my math-reasoning-enhanced Gemma-2B-IT model, fine-tuned on the GSM8K dataset using LoRA with Chain-of-Thought (CoT) formatting, trained entirely on the Google Colab Free Tier.

Because of the free tier's daily session limits, training had to be spread across multiple days with mandatory cooldowns: the accumulated training time is about 19 hours, but the full effort spanned more than a month, including dataset preparation, debugging, and re-running sessions.
## Model Summary

- Base Model: `google/gemma-2b-it`
- Fine-Tuning Method: LoRA (r=64) with CoT-style prompting
- Dataset: GSM8K (main split)
- Training Style: Single-stage curriculum (GSM8K only)
- Platform: Google Colab Free Tier GPU
- Total Effective Training Time: ~19 hours
- Total Wall-Clock Effort: >1 month (due to free-tier 24-hour cooldown cycles)
## Training Configuration

These are the exact hyperparameters used in training:

```python
DATASETS = ["gsm8k"]
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 16
EPOCHS_PER_STAGE = 21
LR = 2e-5
MAX_SEQ_LEN = 512
load_in_4bit = True
bnb_4bit_compute_dtype = bfloat16
optimizer = paged_adamw_32bit
fp16 = True
```
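For reference, here is a minimal sketch of how these hyperparameters could be wired into a PEFT/Transformers training setup. The `target_modules`, `lora_alpha`, and `lora_dropout` values below are assumptions for illustration, not values confirmed by the actual training run:

```python
import torch
from transformers import (AutoModelForCausalLM, BitsAndBytesConfig,
                          TrainingArguments)
from peft import LoraConfig, get_peft_model

# 4-bit quantized base model, matching load_in_4bit / bnb_4bit_compute_dtype above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", quantization_config=bnb_config, device_map="auto"
)

# LoRA at r=64; alpha, dropout, and target modules are illustrative assumptions
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,       # assumption: alpha = 2*r is a common choice
    lora_dropout=0.05,    # assumption
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Training arguments matching the listing above
# (MAX_SEQ_LEN = 512 would be applied at tokenization time)
args = TrainingArguments(
    output_dir="gemma-2b-math-lora-cot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=21,
    learning_rate=2e-5,
    optim="paged_adamw_32bit",
    fp16=True,
)
```

A `Trainer` (or TRL's `SFTTrainer`) would then consume these objects along with the tokenized dataset.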
The model was trained with step-by-step Chain-of-Thought formatting:

```
Let's think step by step:\n{question}
```

Labels were masked so the model learns only from the reasoning steps and final answer, not from the instruction.
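A minimal sketch of that masking step, assuming the usual approach of setting prompt-token labels to `-100` so they are ignored by the cross-entropy loss (the helper name and field layout are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

def build_example(question, reasoning_and_answer, max_len=512):
    prompt = f"Let's think step by step:\n{question}"
    prompt_ids = tokenizer(prompt, add_special_tokens=True)["input_ids"]
    target_ids = tokenizer(reasoning_and_answer, add_special_tokens=False)["input_ids"]
    input_ids = (prompt_ids + target_ids)[:max_len]
    # Mask the instruction: -100 labels are skipped by the loss,
    # so gradients flow only through the reasoning steps + final answer.
    labels = ([-100] * len(prompt_ids) + target_ids)[:max_len]
    return {"input_ids": input_ids, "labels": labels}
```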
## Performance (GSM8K Evaluation)

The model was evaluated on 1,000 GSM8K test samples using deterministic generation (`do_sample=False`, `max_new_tokens=320`). Although the objective was improved reasoning structure rather than raw accuracy, final-answer accuracy still improved measurably.
### Accuracy Comparison

| Model | Correct / Total | Accuracy |
|---|---|---|
| Base Gemma-2B-IT | 100 / 1000 | 10.0% |
| Fine-tuned (LoRA merged) | 122 / 1000 | 12.2% |
### Improvement
- Absolute Gain: +2.2 percentage points
- Relative Improvement: +22% over the base model
- Indicates better reasoning trace quality and slightly better correctness.
### Evaluation Method

- Ground-truth answer extracted from the GSM8K `answer` field (the number after `####`)
- A prediction counts as correct when the extracted final numeric value matches the ground truth
- Uses the same CoT wrapper used during training:

```
Let's think step by step:\n{instruction}
```
## Model Comparison: Base vs Fine-Tuned (GSM8K Math Reasoning)

The script below evaluates and compares:

- `google/gemma-2b-it` (base)
- `Ponmurugaiya72/gemma-2b-math-lora-cot` (fine-tuned merged FP16 model)
### Run Evaluation on GSM8K
```python
import csv
import re

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Models to compare
HF_MERGED_MODEL = "Ponmurugaiya72/gemma-2b-math-lora-cot"
BASE_MODEL = "google/gemma-2b-it"

NUM_SAMPLES = 1000
MAX_NEW_TOKENS = 320
CSV_PATH = "gsm8k_eval.csv"

# -------------------
# Load tokenizer
# -------------------
tokenizer = AutoTokenizer.from_pretrained(HF_MERGED_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# -------------------
# Load base model
# -------------------
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
)

# -------------------
# Load fine-tuned merged FP16 model
# -------------------
merged = AutoModelForCausalLM.from_pretrained(
    HF_MERGED_MODEL, torch_dtype=torch.float16, device_map="auto"
)

# -------------------
# Load GSM8K test set
# -------------------
ds = load_dataset("gsm8k", "main")
test = ds["test"].select(range(NUM_SAMPLES))

# -------------------
# Prompt + helpers
# -------------------
def build_prompt(q):
    return (
        "Solve step-by-step and at the end output:\n"
        "Final Answer: #### <number>\n\n"
        f"Problem: {q}\n\nLet's reason:\n"
    )

def extract_final_answer(text):
    # Prefer the explicit "Final Answer: #### <number>" marker...
    m = re.search(r"Final Answer:\s*#+\s*([0-9\.\-]+)", text)
    if m:
        return m.group(1).strip()
    # ...otherwise fall back to the last number in the response
    nums = re.findall(r"[-]?\d+\.?\d*", text)
    return nums[-1] if nums else None

def generate(model, prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

# -------------------
# Evaluation loop
# -------------------
base_correct = merged_correct = 0
with open(CSV_PATH, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["index", "ground_truth", "base_pred", "merged_pred"])
    for i, sample in enumerate(test):
        q, a = sample["question"], sample["answer"]
        gt = re.search(r"####\s*([0-9\.\-]+)", a).group(1)
        prompt = build_prompt(q)

        base_out = generate(base, prompt)
        merged_out = generate(merged, prompt)
        base_pred = extract_final_answer(base_out)
        merged_pred = extract_final_answer(merged_out)

        if base_pred == gt:
            base_correct += 1
        if merged_pred == gt:
            merged_correct += 1

        writer.writerow([i, gt, base_pred, merged_pred])
        print(f"[{i+1}/{NUM_SAMPLES}] Base={base_pred}, Merged={merged_pred}, GT={gt}")

print("Base Accuracy  :", base_correct / NUM_SAMPLES)
print("Merged Accuracy:", merged_correct / NUM_SAMPLES)
```
## Example Inference (Single Question)
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Ponmurugaiya72/gemma-2b-math-lora-cot"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Use the same CoT wrapper the model was trained with
prompt = "Let's think step by step: A train travels 60 miles in 1.5 hours. What is its speed?"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
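Since this model was built around the Colab free tier, here is a hedged variant that loads the merged model in 4-bit so inference fits comfortably on a free-tier GPU such as a 16 GB T4; the quantization settings below are an assumption about your hardware, not part of the original workflow:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "Ponmurugaiya72/gemma-2b-math-lora-cot"

# 4-bit NF4 quantization keeps memory use low on free-tier GPUs
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```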
## Acknowledgements

- Trained fully on the Google Colab Free Tier
- Thanks to the open-source tools that made this possible: Hugging Face Transformers, PEFT, and Google Colab