# MAP-NEO Conversational Data Preprocessing Pipeline - FIXED VERSION
# Downloads conversational datasets, filters and formats for instruction fine-tuning
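#
# A typical invocation (the flags below are the argparse options defined in main();
# the script filename is assumed to match this file):
#   python conversation_data_prep.py --dataset OpenAssistant/oasst1 \
#       --num_conversations 20000 --format instruction --output_dir data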
import json
import itertools
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
import argparse
import random
from collections import defaultdict
class ConversationDataPreprocessor:
def __init__(self, output_dir="data", max_length=1024):
self.output_dir = Path(output_dir)
self.max_length = max_length
self.setup_directories()
def setup_directories(self):
"""Create necessary directories"""
dirs = ["conversation_raw", "conversation_processed", "conversation_final"]
for d in dirs:
(self.output_dir / d).mkdir(parents=True, exist_ok=True)
def download_conversational_data(self, dataset_name="OpenAssistant/oasst1", num_conversations=20000):
"""Download conversational dataset from HuggingFace"""
print(f"Downloading {num_conversations} conversations from {dataset_name}...")
raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"
try:
# Load dataset
ds = load_dataset(dataset_name, split="train", streaming=True)
downloaded = 0
with open(raw_path, "w", encoding="utf-8") as f:
for row in tqdm(itertools.islice(ds, num_conversations), total=num_conversations):
# Save raw conversation data
f.write(json.dumps(row, ensure_ascii=False) + "\n")
downloaded += 1
print(f"Raw conversational data saved to: {raw_path}")
print(f"Downloaded {downloaded} conversation records")
return raw_path
except Exception as e:
print(f"Error downloading {dataset_name}: {e}")
print("Trying alternative dataset...")
return self.download_alternative_dataset(num_conversations)
def download_alternative_dataset(self, num_conversations=20000):
"""Try alternative conversational datasets if primary fails"""
alternative_datasets = [
"databricks/databricks-dolly-15k",
"tatsu-lab/alpaca",
"vicgalle/alpaca-gpt4"
]
for dataset_name in alternative_datasets:
try:
print(f"Trying {dataset_name}...")
raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"
ds = load_dataset(dataset_name, split="train")
# Sample if dataset is too large
if len(ds) > num_conversations:
ds = ds.shuffle(seed=42).select(range(num_conversations))
with open(raw_path, "w", encoding="utf-8") as f:
for row in tqdm(ds):
f.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"Successfully downloaded {len(ds)} records from {dataset_name}")
return raw_path
except Exception as e:
print(f"Failed to download {dataset_name}: {e}")
continue
raise Exception("All conversational datasets failed to download")
def process_conversations(self, input_path, dataset_name="auto"):
"""Process raw conversational data into standard format"""
print("Processing conversations into standard format...")
input_path = Path(input_path)
# Detect dataset type from filename
if "OpenAssistant" in str(input_path) or "oasst" in str(input_path):
return self.process_openassistant_messages(input_path)
else:
return self.process_other_datasets(input_path)
def process_openassistant_messages(self, input_path):
"""Process OpenAssistant individual messages into conversation chains"""
print("๐Ÿš€ Processing OpenAssistant messages into conversations...")
# Load all messages
messages = []
print("Loading messages...")
with open(input_path, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Reading messages"):
try:
msg = json.loads(line)
# Filter for valid English messages
if (msg.get('lang') == 'en' and
not msg.get('deleted', False) and
msg.get('review_result', False) and
msg.get('text', '').strip()):
messages.append(msg)
                except Exception:
                    # Skip malformed or unexpected records
                    continue
print(f"Loaded {len(messages)} valid English messages")
# Group messages by conversation tree
trees = defaultdict(list)
for msg in messages:
tree_id = msg.get('message_tree_id')
if tree_id:
trees[tree_id].append(msg)
print(f"Found {len(trees)} conversation trees")
# Build conversation chains from each tree
conversations = []
for tree_id, tree_messages in tqdm(trees.items(), desc="Building conversations"):
# Create message lookup
msg_dict = {msg['message_id']: msg for msg in tree_messages}
# Find root messages (no parent)
roots = [msg for msg in tree_messages if not msg.get('parent_id')]
for root in roots:
try:
# Build all possible conversation paths from this root
paths = self.build_conversation_paths(root, msg_dict)
for path in paths:
# Convert to conversation format
conversation = []
for msg in path:
role = "user" if msg['role'] == "prompter" else "assistant"
conversation.append({
"role": role,
"content": msg['text'].strip()
})
# Validate conversation
if self.is_valid_conversation(conversation):
conversations.append({
"messages": conversation,
"tree_id": tree_id,
"source": "oasst1"
})
except Exception as e:
# Skip problematic trees
continue
print(f"Extracted {len(conversations)} valid conversations")
# Save processed conversations
output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
with open(output_path, "w", encoding="utf-8") as f:
for conv in conversations:
f.write(json.dumps(conv, ensure_ascii=False) + "\n")
print(f"Processed data saved to: {output_path}")
return output_path
def build_conversation_paths(self, root_msg, msg_dict, max_length=8):
"""Build all conversation paths starting from a root message - FIXED"""
def build_paths_recursive(msg, current_path):
paths = []
new_path = current_path + [msg]
# Find children of this message
children = []
for candidate in msg_dict.values():
if candidate.get('parent_id') == msg['message_id']:
children.append(candidate)
if not children:
# Leaf node - end of conversation path
if len(new_path) >= 2: # At least user + assistant
paths.append(new_path)
else:
# Continue with each child (take the best ranked one)
# Fix: Handle None values in rank
def get_rank(x):
rank = x.get('rank')
return rank if rank is not None else 999
try:
children.sort(key=get_rank) # Lower rank = better
best_child = children[0]
if len(new_path) < max_length: # Prevent very long conversations
child_paths = build_paths_recursive(best_child, new_path)
paths.extend(child_paths)
# Also save the current path if it's long enough
if len(new_path) >= 2:
paths.append(new_path)
                except Exception:
# If sorting fails, just take the first child
if children and len(new_path) < max_length:
child_paths = build_paths_recursive(children[0], new_path)
paths.extend(child_paths)
return paths
return build_paths_recursive(root_msg, [])
def is_valid_conversation(self, conversation):
"""Validate conversation quality"""
# Must have at least 2 messages
if len(conversation) < 2:
return False
# Check for alternating roles (user/assistant pattern)
for i in range(1, len(conversation)):
if conversation[i]['role'] == conversation[i-1]['role']:
return False
# Check message content quality
for msg in conversation:
content = msg['content']
if len(content) < 5 or len(content) > 1500:
return False
# Check total conversation length
total_length = sum(len(msg['content']) for msg in conversation)
if total_length < 20 or total_length > 3000:
return False
return True
def process_other_datasets(self, input_path):
"""Process non-OpenAssistant datasets (Dolly, Alpaca, etc.)"""
output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
conversations = []
total_count = 0
valid_count = 0
with open(input_path, "r", encoding="utf-8") as infile:
for line in tqdm(infile, desc="Processing conversations"):
total_count += 1
try:
raw_data = json.loads(line)
# Extract conversation based on format
conversation = self.extract_conversation_other_formats(raw_data)
if conversation and self.validate_simple_conversation(conversation):
conversations.append(conversation)
valid_count += 1
except Exception as e:
continue
# Save processed conversations
with open(output_path, "w", encoding="utf-8") as outfile:
for conv in conversations:
outfile.write(json.dumps(conv, ensure_ascii=False) + "\n")
print(f"Processed {valid_count}/{total_count} valid conversations")
print(f"Processed data saved to: {output_path}")
return output_path
def extract_conversation_other_formats(self, raw_data):
"""Extract conversation from various dataset formats"""
# Dolly format
if 'instruction' in raw_data and 'response' in raw_data:
messages = [
{"role": "user", "content": raw_data['instruction'].strip()}
]
if raw_data.get('context'):
messages[0]['content'] += f"\nContext: {raw_data['context'].strip()}"
messages.append({
"role": "assistant",
"content": raw_data['response'].strip()
})
return {
"messages": messages,
"category": raw_data.get('category', 'general'),
"source": "dolly"
}
# Alpaca format
elif 'instruction' in raw_data and 'output' in raw_data:
messages = [
{"role": "user", "content": raw_data['instruction'].strip()}
]
if raw_data.get('input'):
messages[0]['content'] += f"\nInput: {raw_data['input'].strip()}"
messages.append({
"role": "assistant",
"content": raw_data['output'].strip()
})
return {
"messages": messages,
"source": "alpaca"
}
return None
def validate_simple_conversation(self, conversation):
"""Validate conversation quality for simple formats"""
messages = conversation.get('messages', [])
# Must have at least 1 message
if len(messages) < 1:
return False
# Check message content
for msg in messages:
content = msg.get('content', '').strip()
if not content or len(content) < 5:
return False
# Check total length
total_length = sum(len(msg['content']) for msg in messages)
if total_length < 10 or total_length > 2000:
return False
return True
def format_for_training(self, input_path, train_format="instruction"):
"""Format conversations for fine-tuning"""
print(f"Formatting conversations for {train_format} training...")
input_path = Path(input_path)
output_path = self.output_dir / "conversation_final" / "conversation_train.jsonl"
test_path = self.output_dir / "conversation_final" / "conversation_test.jsonl"
conversations = []
# Load processed conversations
with open(input_path, "r", encoding="utf-8") as f:
for line in f:
conv = json.loads(line)
conversations.append(conv)
# Shuffle and split
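        # 90/10 train/test split; note the shuffle is unseeded, so the split differs
        # between runs (unlike the seeded shuffle used for the download fallback).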
random.shuffle(conversations)
split_point = int(len(conversations) * 0.9)
train_conversations = conversations[:split_point]
test_conversations = conversations[split_point:]
# Format for training
self.save_training_format(train_conversations, output_path, train_format)
self.save_training_format(test_conversations, test_path, train_format)
print(f"Training conversations: {len(train_conversations)}")
print(f"Test conversations: {len(test_conversations)}")
print(f"Training data saved to: {output_path}")
print(f"Test data saved to: {test_path}")
# Show samples
if train_conversations:
print("\n๐Ÿ“ Sample conversations:")
for i, conv in enumerate(train_conversations[:3]):
print(f"\nConversation {i+1}:")
for j, msg in enumerate(conv['messages']):
content = msg['content'][:80] + "..." if len(msg['content']) > 80 else msg['content']
print(f" {j+1}. {msg['role'].title()}: {content}")
return output_path, test_path
def save_training_format(self, conversations, output_path, format_type):
"""Save conversations in training format"""
with open(output_path, "w", encoding="utf-8") as f:
for conv in conversations:
messages = conv['messages']
if len(messages) >= 2:
if format_type == "instruction":
# Instruction format: last message is target, rest is input
input_messages = []
for msg in messages[:-1]:
input_messages.append(f"{msg['role'].title()}: {msg['content']}")
training_example = {
"instruction": "Continue this conversation naturally and helpfully.",
"input": "\n".join(input_messages),
"output": messages[-1]['content']
}
elif format_type == "chat":
# Chat format: full conversation with system prompt
training_example = {
"messages": [
{"role": "system", "content": "You are MAP-NEO, a helpful AI assistant."}
] + messages
}
                else:
                    # Unknown format type: skip rather than write an undefined example
                    continue
                f.write(json.dumps(training_example, ensure_ascii=False) + "\n")
def main():
parser = argparse.ArgumentParser(description="Preprocess conversational data for fine-tuning")
parser.add_argument("--dataset", type=str, default="OpenAssistant/oasst1",
help="Dataset to download")
parser.add_argument("--num_conversations", type=int, default=20000,
help="Number of conversations to download")
parser.add_argument("--format", type=str, default="instruction",
choices=["instruction", "chat"],
help="Training format")
parser.add_argument("--output_dir", type=str, default="data",
help="Output directory")
args = parser.parse_args()
# Initialize preprocessor
preprocessor = ConversationDataPreprocessor(args.output_dir)
# Run pipeline
print("Starting conversational data preprocessing pipeline...")
# Step 1: Download conversational data
raw_path = preprocessor.download_conversational_data(
args.dataset, args.num_conversations
)
# Step 2: Process conversations (auto-detects OpenAssistant vs others)
processed_path = preprocessor.process_conversations(raw_path, args.dataset)
# Step 3: Format for training
train_path, test_path = preprocessor.format_for_training(
processed_path, args.format
)
print("\n" + "="*60)
print("๐ŸŽ‰ Conversational data preprocessing complete!")
print(f"Training data: {train_path}")
print(f"Test data: {test_path}")
print("\n๐Ÿš€ Ready for conversational fine-tuning!")
print("Next step: python finetune_conversational.py")
print("="*60)
if __name__ == "__main__":
main()