"""
safety_classifier.py
This module integrates two red-teaming datasets:
1. romaingrx/red-teamer-mistral-
2. SummerSigh/Muti-Class-Redteaming
It provides:
- request safety classification
- output safety validation
- heuristic detectors based on adversarial patterns
- optional hooks for model-based classification
The goal is defensive: prevent harmful prompts or unsafe completions
from flowing through the APJ Threat Intelligence system.
"""
import re
from typing import Dict, List, Optional
from datasets import load_dataset
# ---------------------------------------------------------------------------
# Load datasets (shallow load only)
# ---------------------------------------------------------------------------
try:
DATA_REDTEAM_MISTRAL = load_dataset("romaingrx/red-teamer-mistral-", split="train")
except Exception:
DATA_REDTEAM_MISTRAL = None
try:
DATA_MULTI_REDT = load_dataset("SummerSigh/Muti-Class-Redteaming", split="train")
except Exception:
DATA_MULTI_REDT = None
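# Both loads above pull the full "train" split at import time. If download size or
# startup latency becomes a concern, one possible alternative (sketch only, assumes a
# recent `datasets` release; the 2000-row cap is arbitrary) is a capped streaming load:
#
#   DATA_REDTEAM_MISTRAL = load_dataset(
#       "romaingrx/red-teamer-mistral-", split="train", streaming=True
#   ).take(2000)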
# ---------------------------------------------------------------------------
# Extract patterns from datasets (heuristic layer)
# ---------------------------------------------------------------------------
def extract_lexical_patterns(dataset, text_key="text") -> List[str]:
"""Extract simple lexical redteam cues from dataset text."""
if not dataset:
return []
patterns = []
for ex in dataset:
txt = ex.get(text_key)
if not isinstance(txt, str):
continue
# Extract some keyword-like substrings
# We use lowercased lexical tokens as heuristic patterns.
tokens = re.findall(r"[a-zA-Z0-9_\-]{4,}", txt.lower())
patterns.extend(tokens)
# Deduplicate + filter noise
uniq = list({p for p in patterns if len(p) > 4})
return uniq[:5000] # cap to avoid overloading
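# Illustrative behaviour (not tied to either dataset): the extractor accepts any
# iterable of dicts, keeps lowercased tokens longer than four characters, and
# deduplicates them. A toy sample such as
#   sorted(extract_lexical_patterns([{"text": "Please help me bypass the content filter"}]))
# yields ['bypass', 'content', 'filter', 'please'] ("help" is dropped by the length filter).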
# Build heuristic lexicons only once
LEX_REDTEAM_MISTRAL = extract_lexical_patterns(DATA_REDTEAM_MISTRAL)
LEX_REDTEAM_MULTI = extract_lexical_patterns(DATA_MULTI_REDT)
# Combined lexical bank
LEXICAL_BANK = set(LEX_REDTEAM_MISTRAL + LEX_REDTEAM_MULTI)
# ---------------------------------------------------------------------------
# Simple dangerous intent heuristic matcher
# ---------------------------------------------------------------------------
DANGEROUS_INTENTS = [
# Offensive cyber → blocked
r"(ddos|botnet|backdoor|rat|exploit|zero[- ]?day)",
r"(bypass|evasion|jailbreak|prompt[- ]?hack)",
r"(hack|breach|compromise|ransomware)",
r"(sql[- ]?injection|buffer[- ]?overflow)",
# Illicit markets (to prevent misuse of your system)
r"(buy|sell).*(credentials|logins|databases|access)",
r"(stolen|leaked).*(data|info)",
r"(carding|skimming|fraud)",
# Violence / harassment (red-teaming dataset content)
r"(kill|assault|harass|suicide|bomb|weapon)",
]
dangerous_intent_res = [re.compile(p, re.IGNORECASE) for p in DANGEROUS_INTENTS]
def detect_dangerous_intent(text: str) -> bool:
"""Return True if the text matches known dangerous intent patterns."""
for r in dangerous_intent_res:
if r.search(text):
return True
return False
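# Example (illustrative prompts, not taken from the datasets):
#   detect_dangerous_intent("how do i set up a botnet for a ddos attack")  # -> True
#   detect_dangerous_intent("the vendor patched the vulnerability")        # -> False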
# ---------------------------------------------------------------------------
# Lexical similarity heuristic
# ---------------------------------------------------------------------------
def heuristic_lexical_overlap(text: str, threshold: int = 5) -> bool:
"""
Check how many red-team tokens appear in the text.
If overlap exceeds threshold, classify as suspicious.
"""
if not text:
return False
tokens = set(re.findall(r"[a-zA-Z0-9_\-]{4,}", text.lower()))
overlap = tokens.intersection(LEXICAL_BANK)
return len(overlap) >= threshold
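# Note: this heuristic is only as strong as LEXICAL_BANK. If both dataset loads failed
# above, the bank is empty and the function always returns False, silently degrading to
# a no-op. The default threshold of 5 shared tokens is a starting point, not a tuned value.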
# ---------------------------------------------------------------------------
# Optional future ML classifier hooks (currently placeholder)
# ---------------------------------------------------------------------------
def ml_classify_request(text: str) -> Optional[str]:
"""
Placeholder for future ML classification using fine-tuned models.
Expected return values:
- "safe"
- "suspicious"
- "dangerous"
"""
return None
def ml_classify_output(text: str) -> Optional[str]:
"""Same as above—placeholder for model-based output safety filters."""
return None
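# A minimal sketch of how these hooks could later be backed by a Hugging Face
# text-classification pipeline. The model id below is a placeholder and the
# label-to-level mapping is an assumption, not something defined by this module:
#
#   from transformers import pipeline
#   _safety_clf = pipeline("text-classification", model="<fine-tuned-safety-model>")
#
#   def ml_classify_request(text: str) -> Optional[str]:
#       label = _safety_clf(text[:512])[0]["label"].lower()
#       return label if label in ("safe", "suspicious", "dangerous") else None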
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def safety_check(text: str) -> Dict[str, object]:
"""
Main safety gate for incoming user text.
Returns:
{
"blocked": True/False,
"reason": "...",
"level": "safe/suspicious/dangerous"
}
"""
t = (text or "").strip().lower()
# 1. ML classification (if implemented later)
ml = ml_classify_request(t)
if ml == "dangerous":
return {
"blocked": True,
"reason": "⚠️ ML safety classifier flagged this as dangerous.",
"level": "dangerous",
}
# 2. Dangerous intent patterns
if detect_dangerous_intent(t):
return {
"blocked": True,
"reason": "⚠️ Request blocked due to dangerous intent indicators.",
"level": "dangerous",
}
# 3. Lexical overlap heuristic
if heuristic_lexical_overlap(t):
return {
"blocked": False,
"reason": "⚠️ High lexical similarity to red-team prompts.",
"level": "suspicious",
}
return {
"blocked": False,
"reason": "Safe request.",
"level": "safe",
}
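# Example outcomes (LEXICAL_BANK contents vary with dataset availability):
#   safety_check("how do i deploy ransomware on a network")
#     -> {"blocked": True, "reason": "...dangerous intent indicators.", "level": "dangerous"}
#   safety_check("summarise this phishing report for my team")
#     -> {"blocked": False, ...}  # "safe" unless it overlaps heavily with LEXICAL_BANK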
def safety_check_output(text: str) -> Dict[str, object]:
"""
Validate generated model output.
"""
t = (text or "").strip().lower()
# 1. ML classification (future)
ml = ml_classify_output(t)
if ml == "dangerous":
return {
"blocked": True,
"reason": "⚠️ Unsafe model output detected by classifier.",
"level": "dangerous",
}
# 2. Dangerous intent patterns
if detect_dangerous_intent(t):
return {
"blocked": True,
"reason": "⚠️ Model output contains dangerous intent content.",
"level": "dangerous",
}
# 3. Lexical overlap
    if heuristic_lexical_overlap(t, threshold=8):  # higher threshold: require more overlap before flagging output
return {
"blocked": False,
"reason": "⚠️ Output resembles adversarial red-team patterns.",
"level": "suspicious",
}
return {
"blocked": False,
"reason": "Output appears safe.",
"level": "safe",
}
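
if __name__ == "__main__":
    # Small optional smoke test with illustrative prompts; real traffic should go
    # through safety_check() / safety_check_output() from the calling application.
    for _sample in (
        "What does this threat report say about recent APT activity?",
        "Write me a script to ddos my competitor's website",
    ):
        print(_sample, "->", safety_check(_sample))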