S-Dreamer commited on
Commit
75c4517
·
verified ·
1 Parent(s): b65b044

Create model_inference.py

Browse files
Files changed (1) hide show
  1. model_inference.py +30 -0
model_inference.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_inference.py
2
+
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ import torch
5
+
6
+ class ThreatModel:
7
+ """
8
+ Wraps a transformer classifier for threat categorization.
9
+ """
10
+
11
+ def __init__(self, model_path="bert-base-chinese", device=None):
12
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
14
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
15
+ self.model.to(self.device)
16
+
17
+ def predict(self, text):
18
+ inputs = self.tokenizer(
19
+ text,
20
+ return_tensors="pt",
21
+ truncation=True,
22
+ padding=True
23
+ ).to(self.device)
24
+
25
+ with torch.no_grad():
26
+ outputs = self.model(**inputs)
27
+ logits = outputs.logits
28
+ probs = torch.softmax(logits, dim=-1).cpu().tolist()[0]
29
+
30
+ return probs # list of probabilities per class