"""
safety_classifier.py

This module integrates two red-teaming datasets:

1. romaingrx/red-teamer-mistral-
2. SummerSigh/Muti-Class-Redteaming

It provides:
- request safety classification
- output safety validation
- heuristic detectors based on adversarial patterns
- optional hooks for model-based classification

The goal is defensive: prevent harmful prompts or unsafe completions
from flowing through the APJ Threat Intelligence system.
"""

import logging
import re
from collections import Counter
from typing import Any, Callable, Dict, List, Optional

try:
    from datasets import load_dataset
except ImportError:
    # Dataset loading is already best-effort below, so a missing `datasets`
    # package only disables the lexical-overlap layer; the regex heuristics
    # keep working.
    load_dataset = None

logger = logging.getLogger(__name__)

# Token shape shared by lexicon extraction and overlap matching.
_TOKEN_RE = re.compile(r"[a-zA-Z0-9_\-]{4,}")


# ---------------------------------------------------------------------------
# Load datasets (shallow load only)
# ---------------------------------------------------------------------------

def _load_train_split(dataset_id: str):
    """Return the ``train`` split of *dataset_id*, or ``None`` on any failure.

    Failures (package missing, network error, unknown dataset id, ...) are
    logged instead of raised so that importing this module never fails because
    a dataset is unavailable.
    """
    if load_dataset is None:
        logger.warning(
            "`datasets` package not available; skipping %r.", dataset_id
        )
        return None
    try:
        return load_dataset(dataset_id, split="train")
    except Exception:  # deliberate best-effort boundary: log and degrade
        logger.warning(
            "Could not load dataset %r; its lexicon will be empty.",
            dataset_id,
            exc_info=True,
        )
        return None


# NOTE(review): this dataset id ends with a dash and looks truncated --
# confirm the intended repository name.
DATA_REDTEAM_MISTRAL = _load_train_split("romaingrx/red-teamer-mistral-")
DATA_MULTI_REDT = _load_train_split("SummerSigh/Muti-Class-Redteaming")


# ---------------------------------------------------------------------------
# Extract patterns from datasets (heuristic layer)
# ---------------------------------------------------------------------------

def extract_lexical_patterns(
    dataset, text_key="text", max_patterns: int = 5000
) -> List[str]:
    """Extract simple lexical red-team cues from dataset text.

    Args:
        dataset: Iterable of mapping-like examples (anything with ``.get``),
            or a falsy value (``None``, empty) when no data is available.
        text_key: Key under which each example stores its text. Examples
            whose value is missing or not a ``str`` are skipped.
        max_patterns: Upper bound on the number of tokens returned, to avoid
            overloading the lexical bank.

    Returns:
        Up to ``max_patterns`` distinct lowercased tokens longer than four
        characters, most frequent first (ties keep first-seen order). The
        selection is deterministic across runs.
    """
    if not dataset:
        return []

    counts: Counter = Counter()
    for example in dataset:
        txt = example.get(text_key)
        if not isinstance(txt, str):
            continue
        # Lowercased keyword-like substrings; tokens of length <= 4 are noise.
        counts.update(
            tok for tok in _TOKEN_RE.findall(txt.lower()) if len(tok) > 4
        )

    # Counter keys are already unique; most_common gives a stable ordering,
    # unlike slicing an unordered set.
    return [tok for tok, _ in counts.most_common(max_patterns)]


# Build heuristic lexicons only once
LEX_REDTEAM_MISTRAL = extract_lexical_patterns(DATA_REDTEAM_MISTRAL)
LEX_REDTEAM_MULTI = extract_lexical_patterns(DATA_MULTI_REDT)

# Combined lexical bank
LEXICAL_BANK = set(LEX_REDTEAM_MISTRAL + LEX_REDTEAM_MULTI)


# ---------------------------------------------------------------------------
# Simple dangerous intent heuristic matcher
# ---------------------------------------------------------------------------

# Short alternatives carry word boundaries so they do not fire inside benign
# words: "rat" in "strategy"/"generate", "kill" in "skill", "carding" in
# "discarding".
DANGEROUS_INTENTS = [
    # Offensive cyber -> blocked
    r"(ddos|botnet|backdoor|\brats?\b|exploit|zero[- ]?day)",
    r"(bypass|evasion|jailbreak|prompt[- ]?hack)",
    r"(hack|breach|compromise|ransomware)",
    r"(sql[- ]?injection|buffer[- ]?overflow)",
    # Illicit markets (to prevent misuse of your system)
    r"(buy|sell).*(credentials|logins|databases|access)",
    r"(stolen|leaked).*(data|info)",
    r"(\bcarding|skimming|fraud)",
    # Violence / harassment (red-teaming dataset content)
    r"(\bkill|assault|harass|suicide|bomb|weapon)",
]

dangerous_intent_res = [re.compile(p, re.IGNORECASE) for p in DANGEROUS_INTENTS]


def detect_dangerous_intent(text: str) -> bool:
    """Return True if the text matches known dangerous intent patterns."""
    return any(pattern.search(text) for pattern in dangerous_intent_res)


# ---------------------------------------------------------------------------
# Lexical similarity heuristic
# ---------------------------------------------------------------------------

def heuristic_lexical_overlap(text: str, threshold: int = 5) -> bool:
    """Check how many red-team tokens appear in the text.

    Args:
        text: Text to inspect; empty/falsy text is never suspicious.
        threshold: Minimum number of distinct tokens shared with
            ``LEXICAL_BANK`` for the text to count as suspicious.

    Returns:
        True when the overlap reaches ``threshold``.
    """
    if not text:
        return False

    tokens = set(_TOKEN_RE.findall(text.lower()))
    return len(tokens & LEXICAL_BANK) >= threshold


# ---------------------------------------------------------------------------
# Optional future ML classifier hooks (currently placeholder)
# ---------------------------------------------------------------------------

def ml_classify_request(text: str) -> Optional[str]:
    """Placeholder for future ML classification using fine-tuned models.

    Expected return values:
    - "safe"
    - "suspicious"
    - "dangerous"

    Currently always returns ``None`` (no verdict).
    """
    return None


def ml_classify_output(text: str) -> Optional[str]:
    """Placeholder for model-based output safety filters.

    Same contract as ``ml_classify_request``; currently always ``None``.
    """
    return None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def _verdict(blocked: bool, reason: str, level: str) -> Dict[str, Any]:
    """Build the result dict shared by both public safety gates."""
    return {"blocked": blocked, "reason": reason, "level": level}


def _run_safety_cascade(
    text: Optional[str],
    ml_classifier: Callable[[str], Optional[str]],
    *,
    overlap_threshold: int,
    ml_reason: str,
    intent_reason: str,
    overlap_reason: str,
    safe_reason: str,
) -> Dict[str, Any]:
    """Run the ML -> intent-regex -> lexical-overlap cascade on *text*.

    The first two stages block; the overlap stage only flags as suspicious.
    """
    normalized = (text or "").strip().lower()

    # 1. ML classification (if implemented later).
    # NOTE(review): a "suspicious" ML verdict is currently ignored.
    if ml_classifier(normalized) == "dangerous":
        return _verdict(True, ml_reason, "dangerous")

    # 2. Dangerous intent patterns
    if detect_dangerous_intent(normalized):
        return _verdict(True, intent_reason, "dangerous")

    # 3. Lexical overlap heuristic (advisory only, never blocks)
    if heuristic_lexical_overlap(normalized, threshold=overlap_threshold):
        return _verdict(False, overlap_reason, "suspicious")

    return _verdict(False, safe_reason, "safe")


def safety_check(text: str) -> Dict[str, Any]:
    """Main safety gate for incoming user text.

    Args:
        text: Raw user text; ``None`` or empty is treated as an empty string.

    Returns:
        {
            "blocked": True/False,
            "reason": "...",
            "level": "safe/suspicious/dangerous"
        }
    """
    return _run_safety_cascade(
        text,
        ml_classify_request,
        overlap_threshold=5,
        ml_reason="⚠️ ML safety classifier flagged this as dangerous.",
        intent_reason="⚠️ Request blocked due to dangerous intent indicators.",
        overlap_reason="⚠️ High lexical similarity to red-team prompts.",
        safe_reason="Safe request.",
    )


def safety_check_output(text: str) -> Dict[str, Any]:
    """Validate generated model output.

    Args:
        text: Model output; ``None`` or empty is treated as an empty string.

    Returns:
        Same shape as ``safety_check``: ``blocked`` (bool), ``reason`` (str)
        and ``level`` ("safe", "suspicious" or "dangerous").
    """
    return _run_safety_cascade(
        text,
        ml_classify_output,
        overlap_threshold=8,  # tighten for output
        ml_reason="⚠️ Unsafe model output detected by classifier.",
        intent_reason="⚠️ Model output contains dangerous intent content.",
        overlap_reason="⚠️ Output resembles adversarial red-team patterns.",
        safe_reason="Output appears safe.",
    )