# MAP-NEO Mini: Data Preprocessing Pipeline
# Downloads a RefinedWeb sample (with Matrix as fallback), applies quality and
# English-language filtering, then tokenizes and packs fixed-length sequences
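# Example usage (defaults shown; see main() for all flags):
#   python data_prep.py --num_docs 10000 --seq_length 1024 --tokenizer gpt2 --output_dir data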
import json
import os
import itertools
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer
import langdetect
from tqdm import tqdm
import argparse
class DataPreprocessor:
def __init__(self, output_dir="data", seq_length=1024):
self.output_dir = Path(output_dir)
self.seq_length = seq_length
self.setup_directories()
def setup_directories(self):
"""Create necessary directories"""
dirs = ["shards", "processed", "tokens"]
for d in dirs:
(self.output_dir / d).mkdir(parents=True, exist_ok=True)
def download_refinedweb_sample(self, num_docs=100000):
"""Download a sample from RefinedWeb dataset"""
print(f"Downloading {num_docs} documents from RefinedWeb...")
raw_path = self.output_dir / "shards" / "refinedweb_sample_raw.jsonl"
try:
# Load RefinedWeb dataset
ds = load_dataset("tiiuae/falcon-refinedweb", split="train", streaming=True)
            downloaded = 0
            with open(raw_path, "w", encoding="utf-8") as f:
                progress = tqdm(total=num_docs, desc="Downloading RefinedWeb")
                # Iterate the stream until num_docs documents pass the filter,
                # rather than capping the number of rows read
                for row in ds:
                    # RefinedWeb stores document text under 'content', not 'text'
                    text = (row.get("content") or "").strip()
                    if text and len(text) > 100:  # basic length filter
                        f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
                        downloaded += 1
                        progress.update(1)
                        if downloaded >= num_docs:
                            break
                progress.close()
print(f"Raw RefinedWeb data saved to: {raw_path}")
print(f"Downloaded {downloaded} high-quality documents")
return raw_path
except Exception as e:
print(f"Error downloading RefinedWeb: {e}")
print("Falling back to Matrix dataset...")
return self.download_matrix_sample_fallback(num_docs)
def download_matrix_sample_fallback(self, num_docs=10000):
"""Download a sample from MAP-NEO Matrix dataset"""
print(f"Downloading {num_docs} documents from Matrix dataset...")
raw_path = self.output_dir / "shards" / "matrix_sample_raw.jsonl"
ds = load_dataset("m-a-p/Matrix", split="train", streaming=True)
with open(raw_path, "w", encoding="utf-8") as f:
            for row in tqdm(itertools.islice(ds, num_docs), total=num_docs):
text = row.get("text") or row.get("content") or ""
if text.strip():
f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
print(f"Raw data saved to: {raw_path}")
return raw_path
# def filter_english(self, input_path):
# """Filter documents to English only"""
# print("Filtering documents to English only...")
# input_path = Path(input_path)
# output_path = self.output_dir / "processed" / "matrix_english.jsonl"
# english_count = 0
# total_count = 0
# with open(input_path, "r", encoding="utf-8") as infile, \
# open(output_path, "w", encoding="utf-8") as outfile:
# for line in tqdm(infile):
# total_count += 1
# try:
# obj = json.loads(line)
# text = obj["text"]
# # Skip very short texts
# if len(text) < 50:
# continue
# # Detect language
# if langdetect.detect(text) == "en":
# outfile.write(json.dumps(obj, ensure_ascii=False) + "\n")
# english_count += 1
# except Exception:
# continue
# print(f"Filtered {english_count}/{total_count} documents to English")
# print(f"English data saved to: {output_path}")
# return output_path
def filter_refinedweb_quality(self, input_path):
"""Enhanced quality filtering for RefinedWeb data"""
print("Applying enhanced quality filtering for RefinedWeb...")
input_path = Path(input_path)
output_path = self.output_dir / "processed" / "refinedweb_filtered.jsonl"
filtered_count = 0
total_count = 0
with open(input_path, "r", encoding="utf-8") as infile, \
open(output_path, "w", encoding="utf-8") as outfile:
for line in tqdm(infile, desc="Quality filtering"):
total_count += 1
try:
obj = json.loads(line)
text = obj["text"]
# Enhanced quality filters for web data
if self.is_high_quality_web_text(text):
outfile.write(json.dumps(obj, ensure_ascii=False) + "\n")
filtered_count += 1
except Exception:
continue
print(f"Filtered {filtered_count}/{total_count} documents for quality")
print(f"Filtered data saved to: {output_path}")
return output_path
def is_high_quality_web_text(self, text):
"""Check if web text meets quality standards"""
# Length checks
if len(text) < 200 or len(text) > 10000:
return False
# Language detection
try:
if langdetect.detect(text) != "en":
return False
        except Exception:  # langdetect can raise on very short or ambiguous text
            return False
# Content quality checks
words = text.split()
if len(words) < 50: # Too short
return False
# Check for spam/low-quality indicators
spam_indicators = ['click here', 'buy now', 'free download', '###', '***']
text_lower = text.lower()
spam_count = sum(1 for indicator in spam_indicators if indicator in text_lower)
if spam_count > 2:
return False
# Check for reasonable sentence structure
sentences = text.split('.')
if len(sentences) < 3: # Too few sentences
return False
return True
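    # Illustrative behaviour of the filter above (hypothetical inputs, not
    # taken from the dataset):
    #   "Buy now! Click here! Free download! ***"     -> False (too short, spam phrases)
    #   a long page written in another language       -> False (language check)
    #   a ~300-word English article with normal
    #   sentence structure                            -> True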
def tokenize_and_pack(self, input_path, tokenizer_name="gpt2"):
"""Tokenize documents and pack into fixed-length sequences"""
print(f"Tokenizing with {tokenizer_name} and packing to {self.seq_length} tokens...")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
input_path = Path(input_path)
output_path = self.output_dir / "tokens" / f"packed_{self.seq_length}.txt"
buffer = []
sequences_written = 0
total_tokens = 0
with open(input_path, "r", encoding="utf-8") as infile, \
open(output_path, "w", encoding="utf-8") as outfile:
for line in tqdm(infile, desc="Processing documents"):
try:
text = json.loads(line)["text"]
# Tokenize
tokens = tokenizer.encode(text, add_special_tokens=False)
# Add to buffer with EOS token
buffer.extend(tokens + [tokenizer.eos_token_id])
total_tokens += len(tokens) + 1
# Pack complete sequences
while len(buffer) >= self.seq_length:
sequence = buffer[:self.seq_length]
buffer = buffer[self.seq_length:]
# Write sequence as space-separated integers
outfile.write(" ".join(map(str, sequence)) + "\n")
sequences_written += 1
                except Exception:  # skip malformed JSON lines or tokenizer errors
                    continue
print(f"Created {sequences_written} sequences of {self.seq_length} tokens each")
print(f"Total tokens processed: {total_tokens:,}")
print(f"Packed data saved to: {output_path}")
# Save tokenizer for later use
tokenizer_path = self.output_dir / "tokenizer"
tokenizer.save_pretrained(tokenizer_path)
print(f"Tokenizer saved to: {tokenizer_path}")
return output_path, tokenizer_path
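    # A minimal sketch of how the packed output could be consumed downstream,
    # assuming a PyTorch training loop; load_packed_sequences is an illustrative
    # helper name, not part of the original pipeline.
    def load_packed_sequences(self, packed_path):
        """Read packed sequences back into a (num_sequences, seq_length) LongTensor."""
        import torch  # assumed available for training; not a dependency of the pipeline itself
        sequences = []
        with open(packed_path, "r", encoding="utf-8") as f:
            for line in f:
                # Each line is one sequence of space-separated token ids
                ids = [int(tok) for tok in line.split()]
                if len(ids) == self.seq_length:
                    sequences.append(ids)
        return torch.tensor(sequences, dtype=torch.long)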
def main():
parser = argparse.ArgumentParser(description="Preprocess MAP-NEO training data")
parser.add_argument("--num_docs", type=int, default=10000,
help="Number of documents to download")
parser.add_argument("--seq_length", type=int, default=1024,
help="Sequence length for packing")
parser.add_argument("--tokenizer", type=str, default="gpt2",
help="Tokenizer to use")
parser.add_argument("--output_dir", type=str, default="data",
help="Output directory")
args = parser.parse_args()
# Initialize preprocessor
preprocessor = DataPreprocessor(args.output_dir, args.seq_length)
# Run pipeline
print("Starting MAP-NEO data preprocessing pipeline...")
# Step 1: Download sample
raw_path = preprocessor.download_refinedweb_sample(args.num_docs)
    # Step 2: Quality filtering (includes English-language check)
filtered_path = preprocessor.filter_refinedweb_quality(raw_path)
# Step 3: Tokenize and pack
packed_path, tokenizer_path = preprocessor.tokenize_and_pack(
filtered_path, args.tokenizer
)
print("\n" + "="*50)
print("Data preprocessing complete!")
print(f"Packed sequences: {packed_path}")
print(f"Tokenizer: {tokenizer_path}")
print("="*50)
if __name__ == "__main__":
main()