import json
import os
import itertools
import argparse
import random
from pathlib import Path
from collections import defaultdict

from datasets import load_dataset
from transformers import AutoTokenizer
import langdetect
from tqdm import tqdm
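
# Pipeline overview (summarising the steps implemented below):
#   1. download_conversational_data  -> data/conversation_raw/<dataset>_raw.jsonl
#   2. process_conversations         -> data/conversation_processed/conversations_standardized.jsonl
#   3. format_for_training           -> data/conversation_final/conversation_train.jsonl and conversation_test.jsonl
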
class ConversationDataPreprocessor:
    """Download, standardize, and split conversational data for fine-tuning."""

    def __init__(self, output_dir="data", max_length=1024):
        self.output_dir = Path(output_dir)
        self.max_length = max_length
        self.setup_directories()

    def setup_directories(self):
        """Create necessary directories"""
        dirs = ["conversation_raw", "conversation_processed", "conversation_final"]
        for d in dirs:
            (self.output_dir / d).mkdir(parents=True, exist_ok=True)
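
    # Resulting layout under output_dir (default "data"):
    #   conversation_raw/        raw JSONL dumps straight from the Hub
    #   conversation_processed/  standardized {"messages": [...]} records
    #   conversation_final/      train/test splits in the chosen training format
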
    def download_conversational_data(self, dataset_name="OpenAssistant/oasst1", num_conversations=20000):
        """Download conversational dataset from HuggingFace"""
        print(f"Downloading {num_conversations} conversations from {dataset_name}...")

        raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

        try:
            ds = load_dataset(dataset_name, split="train", streaming=True)

            downloaded = 0
            with open(raw_path, "w", encoding="utf-8") as f:
                for row in tqdm(itertools.islice(ds, num_conversations), total=num_conversations):
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
                    downloaded += 1

            print(f"Raw conversational data saved to: {raw_path}")
            print(f"Downloaded {downloaded} conversation records")
            return raw_path

        except Exception as e:
            print(f"Error downloading {dataset_name}: {e}")
            print("Trying alternative dataset...")
            return self.download_alternative_dataset(num_conversations)

    def download_alternative_dataset(self, num_conversations=20000):
        """Try alternative conversational datasets if primary fails"""
        alternative_datasets = [
            "databricks/databricks-dolly-15k",
            "tatsu-lab/alpaca",
            "vicgalle/alpaca-gpt4"
        ]

        for dataset_name in alternative_datasets:
            try:
                print(f"Trying {dataset_name}...")
                raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

                ds = load_dataset(dataset_name, split="train")

                if len(ds) > num_conversations:
                    ds = ds.shuffle(seed=42).select(range(num_conversations))

                with open(raw_path, "w", encoding="utf-8") as f:
                    for row in tqdm(ds):
                        f.write(json.dumps(row, ensure_ascii=False) + "\n")

                print(f"Successfully downloaded {len(ds)} records from {dataset_name}")
                return raw_path

            except Exception as e:
                print(f"Failed to download {dataset_name}: {e}")
                continue

        raise RuntimeError("All conversational datasets failed to download")

    def process_conversations(self, input_path, dataset_name="auto"):
        """Process raw conversational data into standard format"""
        print("Processing conversations into standard format...")

        input_path = Path(input_path)

        if "OpenAssistant" in str(input_path) or "oasst" in str(input_path):
            return self.process_openassistant_messages(input_path)
        else:
            return self.process_other_datasets(input_path)
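
    # The raw oasst1 dump is one message per line, not one conversation per line.
    # The processing below relies on these per-message fields (names as read by
    # this script; see the oasst1 dataset card for the authoritative schema):
    #   message_id, parent_id, message_tree_id  - tree structure
    #   role ("prompter"/"assistant"), text     - content
    #   lang, deleted, review_result, rank      - filtering and reply ranking
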
    def process_openassistant_messages(self, input_path):
        """Process OpenAssistant individual messages into conversation chains"""

        print("Processing OpenAssistant messages into conversations...")

        messages = []
        print("Loading messages...")

        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Reading messages"):
                try:
                    msg = json.loads(line)

                    # Keep reviewed, non-deleted English messages with non-empty text
                    if (msg.get('lang') == 'en' and
                            not msg.get('deleted', False) and
                            msg.get('review_result', False) and
                            msg.get('text', '').strip()):
                        messages.append(msg)
                except json.JSONDecodeError:
                    continue

        print(f"Loaded {len(messages)} valid English messages")

        # Group messages by conversation tree
        trees = defaultdict(list)
        for msg in messages:
            tree_id = msg.get('message_tree_id')
            if tree_id:
                trees[tree_id].append(msg)

        print(f"Found {len(trees)} conversation trees")

        # Walk each tree from its root prompts and turn reply chains into conversations
        conversations = []

        for tree_id, tree_messages in tqdm(trees.items(), desc="Building conversations"):
            msg_dict = {msg['message_id']: msg for msg in tree_messages}

            roots = [msg for msg in tree_messages if not msg.get('parent_id')]

            for root in roots:
                try:
                    paths = self.build_conversation_paths(root, msg_dict)

                    for path in paths:
                        conversation = []
                        for msg in path:
                            role = "user" if msg['role'] == "prompter" else "assistant"
                            conversation.append({
                                "role": role,
                                "content": msg['text'].strip()
                            })

                        if self.is_valid_conversation(conversation):
                            conversations.append({
                                "messages": conversation,
                                "tree_id": tree_id,
                                "source": "oasst1"
                            })
                except Exception:
                    continue

        print(f"Extracted {len(conversations)} valid conversations")

        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                f.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed data saved to: {output_path}")
        return output_path

    def build_conversation_paths(self, root_msg, msg_dict, max_length=8):
        """Build conversation paths starting from a root message"""

        def build_paths_recursive(msg, current_path):
            paths = []
            new_path = current_path + [msg]

            # Collect direct replies to this message
            children = []
            for candidate in msg_dict.values():
                if candidate.get('parent_id') == msg['message_id']:
                    children.append(candidate)

            if not children:
                # Leaf: keep the path if it contains at least one exchange
                if len(new_path) >= 2:
                    paths.append(new_path)
            else:
                # Prefer the highest-ranked reply (lower rank value is better)
                def get_rank(x):
                    rank = x.get('rank')
                    return rank if rank is not None else 999

                try:
                    children.sort(key=get_rank)
                    best_child = children[0]

                    if len(new_path) < max_length:
                        child_paths = build_paths_recursive(best_child, new_path)
                        paths.extend(child_paths)

                    # Also keep the partial conversation that ends at this turn
                    if len(new_path) >= 2:
                        paths.append(new_path)
                except Exception:
                    # Ranks were not comparable; fall back to the first reply
                    if children and len(new_path) < max_length:
                        child_paths = build_paths_recursive(children[0], new_path)
                        paths.extend(child_paths)

            return paths

        return build_paths_recursive(root_msg, [])
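
    # Example (hypothetical tree, not real data): a prompt P with two ranked
    # assistant replies A0 (rank 0) and A1 (rank 1), where A0 has a follow-up
    # prompt P2 answered by A2, yields the paths
    #   [P, A0, P2, A2], [P, A0, P2], [P, A0]
    # i.e. only the best-ranked branch is followed, and every prefix with at
    # least one exchange is kept as its own candidate conversation.
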
    def is_valid_conversation(self, conversation):
        """Validate conversation quality"""

        # Require at least one user/assistant exchange
        if len(conversation) < 2:
            return False

        # Roles must alternate
        for i in range(1, len(conversation)):
            if conversation[i]['role'] == conversation[i - 1]['role']:
                return False

        # Per-message length bounds
        for msg in conversation:
            content = msg['content']
            if len(content) < 5 or len(content) > 1500:
                return False

        # Total length bounds
        total_length = sum(len(msg['content']) for msg in conversation)
        if total_length < 20 or total_length > 3000:
            return False

        return True

    def process_other_datasets(self, input_path):
        """Process non-OpenAssistant datasets (Dolly, Alpaca, etc.)"""

        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        conversations = []
        total_count = 0
        valid_count = 0

        with open(input_path, "r", encoding="utf-8") as infile:
            for line in tqdm(infile, desc="Processing conversations"):
                total_count += 1
                try:
                    raw_data = json.loads(line)

                    conversation = self.extract_conversation_other_formats(raw_data)

                    if conversation and self.validate_simple_conversation(conversation):
                        conversations.append(conversation)
                        valid_count += 1

                except Exception:
                    continue

        with open(output_path, "w", encoding="utf-8") as outfile:
            for conv in conversations:
                outfile.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed {valid_count}/{total_count} valid conversations")
        print(f"Processed data saved to: {output_path}")
        return output_path

    def extract_conversation_other_formats(self, raw_data):
        """Extract conversation from various dataset formats"""

        # Dolly-style records: instruction / context / response / category
        if 'instruction' in raw_data and 'response' in raw_data:
            messages = [
                {"role": "user", "content": raw_data['instruction'].strip()}
            ]
            if raw_data.get('context'):
                messages[0]['content'] += f"\nContext: {raw_data['context'].strip()}"

            messages.append({
                "role": "assistant",
                "content": raw_data['response'].strip()
            })

            return {
                "messages": messages,
                "category": raw_data.get('category', 'general'),
                "source": "dolly"
            }

        # Alpaca-style records: instruction / input / output
        elif 'instruction' in raw_data and 'output' in raw_data:
            messages = [
                {"role": "user", "content": raw_data['instruction'].strip()}
            ]
            if raw_data.get('input'):
                messages[0]['content'] += f"\nInput: {raw_data['input'].strip()}"

            messages.append({
                "role": "assistant",
                "content": raw_data['output'].strip()
            })

            return {
                "messages": messages,
                "source": "alpaca"
            }

        return None
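
    # Illustrative mapping (made-up record, not taken from the datasets): a
    # Dolly-style row {"instruction": "Summarize X", "context": "X is ...",
    # "response": "..."} becomes
    #   {"messages": [{"role": "user", "content": "Summarize X\nContext: X is ..."},
    #                 {"role": "assistant", "content": "..."}],
    #    "category": "...", "source": "dolly"}
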
    def validate_simple_conversation(self, conversation):
        """Validate conversation quality for simple formats"""
        messages = conversation.get('messages', [])

        if len(messages) < 1:
            return False

        for msg in messages:
            content = msg.get('content', '').strip()
            if not content or len(content) < 5:
                return False

        total_length = sum(len(msg['content']) for msg in messages)
        if total_length < 10 or total_length > 2000:
            return False

        return True

    def format_for_training(self, input_path, train_format="instruction"):
        """Format conversations for fine-tuning"""
        print(f"Formatting conversations for {train_format} training...")

        input_path = Path(input_path)
        output_path = self.output_dir / "conversation_final" / "conversation_train.jsonl"
        test_path = self.output_dir / "conversation_final" / "conversation_test.jsonl"

        conversations = []

        with open(input_path, "r", encoding="utf-8") as f:
            for line in f:
                conv = json.loads(line)
                conversations.append(conv)

        # Shuffle and split 90/10 into train/test
        random.shuffle(conversations)
        split_point = int(len(conversations) * 0.9)
        train_conversations = conversations[:split_point]
        test_conversations = conversations[split_point:]

        self.save_training_format(train_conversations, output_path, train_format)
        self.save_training_format(test_conversations, test_path, train_format)

        print(f"Training conversations: {len(train_conversations)}")
        print(f"Test conversations: {len(test_conversations)}")
        print(f"Training data saved to: {output_path}")
        print(f"Test data saved to: {test_path}")

        # Show a few samples as a quick sanity check
        if train_conversations:
            print("\nSample conversations:")
            for i, conv in enumerate(train_conversations[:3]):
                print(f"\nConversation {i + 1}:")
                for j, msg in enumerate(conv['messages']):
                    content = (msg['content'][:80] + "...") if len(msg['content']) > 80 else msg['content']
                    print(f"  {j + 1}. {msg['role'].title()}: {content}")

        return output_path, test_path

    def save_training_format(self, conversations, output_path, format_type):
        """Save conversations in training format"""

        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                messages = conv['messages']

                if len(messages) >= 2:
                    if format_type == "instruction":
                        # All but the last turn become the input; the last turn is the target
                        input_messages = []
                        for msg in messages[:-1]:
                            input_messages.append(f"{msg['role'].title()}: {msg['content']}")

                        training_example = {
                            "instruction": "Continue this conversation naturally and helpfully.",
                            "input": "\n".join(input_messages),
                            "output": messages[-1]['content']
                        }

                    elif format_type == "chat":
                        # Prepend a system message and keep the full turn structure
                        training_example = {
                            "messages": [
                                {"role": "system", "content": "You are MAP-NEO, a helpful AI assistant."}
                            ] + messages
                        }

                    else:
                        # Unknown format: skip rather than writing an undefined example
                        continue

                    f.write(json.dumps(training_example, ensure_ascii=False) + "\n")
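
    # Output record shapes written above (values abbreviated for illustration):
    #   "instruction" format:
    #     {"instruction": "Continue this conversation naturally and helpfully.",
    #      "input": "User: ...\nAssistant: ...", "output": "<final turn's content>"}
    #   "chat" format:
    #     {"messages": [{"role": "system", "content": "You are MAP-NEO, ..."},
    #                   {"role": "user", ...}, {"role": "assistant", ...}]}
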
def main():
    parser = argparse.ArgumentParser(description="Preprocess conversational data for fine-tuning")
    parser.add_argument("--dataset", type=str, default="OpenAssistant/oasst1",
                        help="Dataset to download")
    parser.add_argument("--num_conversations", type=int, default=20000,
                        help="Number of conversations to download")
    parser.add_argument("--format", type=str, default="instruction",
                        choices=["instruction", "chat"],
                        help="Training format")
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory")

    args = parser.parse_args()

    preprocessor = ConversationDataPreprocessor(args.output_dir)

    print("Starting conversational data preprocessing pipeline...")

    # Step 1: download raw data
    raw_path = preprocessor.download_conversational_data(
        args.dataset, args.num_conversations
    )

    # Step 2: standardize into {"messages": [...]} records
    processed_path = preprocessor.process_conversations(raw_path, args.dataset)

    # Step 3: format and split for training
    train_path, test_path = preprocessor.format_for_training(
        processed_path, args.format
    )

    print("\n" + "=" * 60)
    print("Conversational data preprocessing complete!")
    print(f"Training data: {train_path}")
    print(f"Test data: {test_path}")
    print("\nReady for conversational fine-tuning!")
    print("Next step: python finetune_conversational.py")
    print("=" * 60)


if __name__ == "__main__":
    main()
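
# Example invocations (script filename assumed; adjust to however this file is saved):
#   python preprocess_conversations.py
#   python preprocess_conversations.py --dataset OpenAssistant/oasst1 --num_conversations 5000 --format chat
#   python preprocess_conversations.py --dataset databricks/databricks-dolly-15k --format instruction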