S-Dreamer committed on
Commit
a2f7f62
·
verified ·
1 Parent(s): f4c8195

Update model_inference.py

Browse files
Files changed (1) hide show
  1. model_inference.py +16 -26
model_inference.py CHANGED
@@ -4,54 +4,44 @@ try:
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  import torch
6
  except ImportError:
7
- AutoModelForSequenceClassification = None # type: ignore
8
- AutoTokenizer = None # type: ignore
9
- torch = None # type: ignore
10
 
11
 
12
  class ThreatModel:
13
  """
14
- Wraps a transformer classifier for threat categorization.
15
-
16
- If `transformers` or `torch` are not installed, this class will gracefully
17
- degrade and simply return empty probability lists instead of crashing.
18
  """
19
-
20
- def __init__(self, model_path: str = "bert-base-chinese", device: Optional[str] = None):
21
  self.available = AutoModelForSequenceClassification is not None and torch is not None
22
  self.model = None
23
  self.tokenizer = None
24
  self.device = "cpu"
25
 
26
  if not self.available:
27
- # No transformers / torch in the environment; operate in dummy mode.
28
  return
29
 
30
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") # type: ignore[attr-defined]
31
- self.tokenizer = AutoTokenizer.from_pretrained(model_path) # type: ignore[call-arg]
32
- self.model = AutoModelForSequenceClassification.from_pretrained(model_path) # type: ignore[call-arg]
33
- self.model.to(self.device) # type: ignore[union-attr]
34
 
35
  def predict_proba(self, text: str) -> List[float]:
36
- """
37
- Return a list of probabilities per class.
38
-
39
- If the model is not available (e.g. transformers not installed),
40
- returns an empty list and lets the caller decide how to handle it.
41
- """
42
  if not self.available or self.model is None or self.tokenizer is None:
43
  return []
44
 
45
- inputs = self.tokenizer( # type: ignore[union-attr]
46
  text,
47
  return_tensors="pt",
48
  truncation=True,
49
  padding=True
50
- ).to(self.device) # type: ignore[union-attr]
51
 
52
- with torch.no_grad(): # type: ignore[union-attr]
53
- outputs = self.model(**inputs) # type: ignore[operator]
54
- logits = outputs.logits # type: ignore[union-attr]
55
- probs = torch.softmax(logits, dim=-1).cpu().tolist()[0] # type: ignore[union-attr]
56
 
57
  return probs
 
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  import torch
6
  except ImportError:
7
+ AutoModelForSequenceClassification = None
8
+ AutoTokenizer = None
9
+ torch = None
10
 
11
 
class ThreatModel:
    """
    Thin wrapper around a transformer sequence classifier.

    If the optional `transformers` / `torch` imports failed (the module-level
    names are then set to None), the instance runs in dummy mode: nothing is
    loaded and `predict_proba` returns an empty list instead of raising.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Load the tokenizer and classifier found at `model_path`.

        Args:
            model_path: Identifier or path handed to `from_pretrained` for both
                the tokenizer and the classification model.
            device: Device string to run on. When omitted, "cuda" is chosen if
                `torch.cuda.is_available()` reports True, otherwise "cpu".
        """
        # All three optional imports are needed below, so all three are checked;
        # any of them being None means the import guard fell back to dummy mode.
        self.available = (
            AutoModelForSequenceClassification is not None
            and AutoTokenizer is not None
            and torch is not None
        )
        self.model = None
        self.tokenizer = None
        self.device = "cpu"

        if not self.available:
            # Dummy mode: keep the placeholders above and load nothing.
            return

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

    def predict_proba(self, text: str) -> List[float]:
        """
        Return the softmax of the model's logits for `text` as a list.

        Args:
            text: The input string to classify.

        Returns:
            The softmax over the last logits dimension for the first (only)
            row of the batch, or an empty list when the instance is in dummy
            mode or has no model/tokenizer loaded.
        """
        if not self.available or self.model is None or self.tokenizer is None:
            return []

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(self.device)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1).cpu().tolist()[0]

        return probs