Gemma-2B-GSM8K-CoT-LoRA (Fine-tuned on Google Colab Free Tier)

This repository contains my math-reasoning-enhanced version of Gemma-2B-IT, fine-tuned on the GSM8K dataset using LoRA with Chain-of-Thought (CoT) formatting, trained entirely on the Google Colab Free Tier.

Free-tier daily session limits meant training had to be spread across multiple days with mandatory cooldowns. Although the accumulated training time is about 19 hours, the full effort spanned more than one month, including dataset preparation, debugging, and re-running sessions.


Model Summary

  • Base Model: google/gemma-2b-it
  • Fine-Tuning Method: LoRA (r=64), CoT-style prompting
  • Dataset: GSM8K (main split)
  • Training Style: Single-stage curriculum (GSM8K only)
  • Platform: Google Colab Free Tier GPU
  • Total Effective Training Time: ~19 hours
  • Total Wall-Clock Effort: >1 month (due to free-tier 24-hour cooldown cycles)

Training Configuration

These are the exact hyperparameters used in training:

DATASETS = ["gsm8k"]

PER_DEVICE_BATCH = 1
GRAD_ACCUM = 16
EPOCHS_PER_STAGE = 21
LR = 2e-5
MAX_SEQ_LEN = 512

load_in_4bit = True
bnb_4bit_compute_dtype = bfloat16
optimizer = paged_adamw_32bit
fp16 = True
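
For reference, the sketch below shows how this configuration could map onto the Hugging Face stack (transformers + peft + bitsandbytes). Only r=64 and the values listed above come from the actual run; the LoRA alpha, dropout, target modules, and output directory are placeholders, not the exact training setup.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit quantized base model, matching load_in_4bit / bfloat16 compute above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", quantization_config=bnb_config, device_map="auto"
)
model = prepare_model_for_kbit_training(model)

# LoRA adapter; r=64 is from the config above, the remaining values are assumed
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,                                             # assumption
    lora_dropout=0.05,                                         # assumption
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],   # assumption
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Training arguments mirroring the listed hyperparameters
args = TrainingArguments(
    output_dir="gemma-2b-gsm8k-lora",       # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=21,
    learning_rate=2e-5,
    optim="paged_adamw_32bit",
    fp16=True,
)

A standard transformers Trainer (or trl's SFTTrainer) would then consume these arguments together with the CoT-formatted GSM8K examples described below.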

The model was trained using step-by-step Chain-of-Thought formatting:

"Let's think step by step:\n{question}"

Labels were masked so that the loss is computed only on the reasoning steps and final answer, not on the instruction.
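
The sketch below illustrates the prompt formatting and label masking described above; build_example is a hypothetical helper, and the exact tokenization details of the original training script may differ.

def build_example(tokenizer, question, reasoning_and_answer, max_len=512):
    # CoT-style wrapper used during training
    prompt = f"Let's think step by step:\n{question}"

    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    target_ids = tokenizer(reasoning_and_answer + tokenizer.eos_token,
                           add_special_tokens=False)["input_ids"]

    input_ids = (prompt_ids + target_ids)[:max_len]
    # Mask the instruction tokens with -100 so the loss is computed only on
    # the reasoning steps and the final answer
    labels = ([-100] * len(prompt_ids) + target_ids)[:max_len]

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }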


Performance (GSM8K Evaluation)

The model was evaluated on the first 1,000 GSM8K test samples using deterministic generation (do_sample = False, max_new_tokens = 320).

Although the primary objective was improved reasoning structure rather than raw accuracy, final-answer accuracy still improved measurably.

Accuracy Comparison

Model                       Correct / Total   Accuracy
Base Gemma-2B-IT            100 / 1000        10.0%
Fine-tuned (LoRA Merged)    122 / 1000        12.2%

Improvement

  • Absolute Gain: +2.2 percentage points
  • Relative Improvement: +22% over the base model
  • Indicates better reasoning trace quality and slightly better correctness.
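
As a quick sanity check, these gains follow directly from the raw counts:

base_correct, tuned_correct, total = 100, 122, 1000

absolute_gain_pp = (tuned_correct - base_correct) / total * 100       # 2.2 percentage points
relative_gain = (tuned_correct - base_correct) / base_correct * 100   # 22% relative to base

print(f"Absolute gain: +{absolute_gain_pp:.1f} pp, relative gain: +{relative_gain:.0f}%")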

Evaluation Method

  • Ground-truth answer parsed from the GSM8K "answer" field (the number following the final "####" marker)

  • A prediction is considered correct if the final numeric value extracted from the model response matches the ground truth

  • Uses the same CoT wrapper used during training:

    "Let's think step by step:\n{instruction}"
    

Model Comparison: Base vs Fine-Tuned (GSM8K Math Reasoning)

The script below evaluates and compares the following two models on GSM8K:

  • google/gemma-2b-it (Base)
  • Ponmurugaiya72/gemma-2b-math-lora-cot (Fine-tuned merged FP16 model)

Run Evaluation on GSM8K

import re, csv, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Your HF model
HF_MERGED_MODEL = "Ponmurugaiya72/gemma-2b-math-lora-cot"
BASE_MODEL = "google/gemma-2b-it"

NUM_SAMPLES = 1000
MAX_NEW_TOKENS = 320
CSV_PATH = "gsm8k_eval.csv"


# -------------------
# Load tokenizer
# -------------------
tokenizer = AutoTokenizer.from_pretrained(HF_MERGED_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# -------------------
# Load base model
# -------------------
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
)

# -------------------
# Load fine-tuned merged FP16 model
# -------------------
merged = AutoModelForCausalLM.from_pretrained(
    HF_MERGED_MODEL, torch_dtype=torch.float16, device_map="auto"
)


# -------------------
# Load GSM8K test set
# -------------------
ds = load_dataset("gsm8k", "main")
test = ds["test"].select(range(NUM_SAMPLES))


# -------------------
# Prompt + helpers
# -------------------
def build_prompt(q):
    return (
        "Solve step-by-step and at the end output:\n"
        "Final Answer: #### <number>\n\n"
        f"Problem: {q}\n\nLet's reason:\n"
    )

def extract_final_answer(text):
    m = re.search(r"Final Answer:\s*#+\s*([0-9\.\-]+)", text)
    if m:
        return m.group(1).strip()
    nums = re.findall(r"[-]?\d+\.?\d*", text)
    return nums[-1] if nums else None

def generate(model, prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)


# -------------------
# Evaluation loop
# -------------------
base_correct = merged_correct = 0
rows = []

for i, sample in enumerate(test):
    q, a = sample["question"], sample["answer"]

    # GSM8K stores the ground-truth number after the final "####" marker
    gt = re.search(r"####\s*([0-9\.\-]+)", a).group(1)

    prompt = build_prompt(q)

    base_out = generate(base, prompt)
    merged_out = generate(merged, prompt)

    base_pred = extract_final_answer(base_out)
    merged_pred = extract_final_answer(merged_out)

    if base_pred == gt:
        base_correct += 1
    if merged_pred == gt:
        merged_correct += 1

    rows.append([i, gt, base_pred, merged_pred])
    print(f"[{i+1}/{NUM_SAMPLES}] Base={base_pred}, Merged={merged_pred}, GT={gt}")


# Save per-sample predictions to CSV_PATH and report overall accuracy
with open(CSV_PATH, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["index", "ground_truth", "base_pred", "merged_pred"])
    writer.writerows(rows)

print("Base Accuracy   :", base_correct / NUM_SAMPLES)
print("Merged Accuracy :", merged_correct / NUM_SAMPLES)

Example Inference (Single Question)

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "Ponmurugaiya72/gemma-2b-math-lora-cot"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "Let's think step by step: A train travels 60 miles in 1.5 hours. What is its speed?"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256, do_sample=False)

print(tokenizer.decode(output[0], skip_special_tokens=True))

Acknowledgements

  • Trained entirely on the Google Colab Free Tier
  • Thanks to Hugging Face Transformers, PEFT, and Google Colab
