iamrazi committed
Commit 5bbfc9f · verified · Parent: 32a5a2f

Update README.md

Files changed (1): README.md (+35, -29)
README.md CHANGED
@@ -22,42 +22,48 @@ tags:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

- tokenizer = AutoTokenizer.from_pretrained("iamrazi/text-moderation")
+ tokenizer = AutoTokenizer.from_pretrained("iamrazi/text-moderation") #
+
model = AutoModelForSequenceClassification.from_pretrained("iamrazi/text-moderation")

model.eval()  # Set model to evaluation mode

- def predict_abuse(text: str, threshold: float = 0.5):
-     """
-     Predict if a text is abusive or not.
-
-     Args:
-         text (str): Input text.
-         threshold (float): Probability threshold for classification.
-
-     Returns:
-         label (int): 0 for non-abusive, 1 for abusive
-         proba (float): Probability of being abusive
-     """
-     # Tokenize
-     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
-
-     # Forward pass
-     with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-         probas = torch.sigmoid(logits)  # if your model output layer is logits
-
-     # For binary classification, take the probability of class 1
-     prob = probas[0][1].item() if probas.shape[1] > 1 else probas[0][0].item()
-
-     # Determine label
-     label = 1 if prob >= threshold else 0
-
-     return label, prob
+
+
+ def predict_abuse(text: str, threshold: float = 0.5):
+
+     """
+     Predict if a text is abusive or not.
+
+     Args:
+         text (str): Input text.
+         threshold (float): Probability threshold for classification.
+
+     Returns:
+         label (int): 0 for non-abusive, 1 for abusive
+         proba (float): Probability of being abusive
+     """
+     # Tokenize
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
+
+     # Forward pass
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probas = torch.sigmoid(logits)  # if your model output layer is logits
+
+     # For binary classification, take the probability of class 1
+     prob = probas[0][1].item() if probas.shape[1] > 1 else probas[0][0].item()
+
+     # Determine label
+     label = 1 if prob >= threshold else 0
+
+     return label, prob


text = "तुम बहुत गंदे हो 😡"
+
+
label, proba = predict_abuse(text)

Output: Label: 0, Probability: 0.08
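
For reference, the sketch below restates the committed snippet as a single runnable script. Two things here are additions, not part of the commit: the `import torch` line (the snippet calls `torch.no_grad()` and `torch.sigmoid()` but only imports from `transformers`) and the final `print`, included only to reproduce the quoted "Label: 0, Probability: 0.08" output format. Everything else mirrors the README as committed.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("iamrazi/text-moderation")
model = AutoModelForSequenceClassification.from_pretrained("iamrazi/text-moderation")
model.eval()  # inference only: disable dropout etc.

def predict_abuse(text: str, threshold: float = 0.5):
    """Return (label, probability): 1 = abusive, 0 = non-abusive."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid over the logits, as in the README; for a single-label
    # softmax/cross-entropy head, torch.softmax(logits, dim=-1) would be usual.
    probas = torch.sigmoid(logits)
    prob = probas[0][1].item() if probas.shape[1] > 1 else probas[0][0].item()
    return (1 if prob >= threshold else 0), prob

text = "तुम बहुत गंदे हो 😡"  # Hindi: "You are very dirty 😡"
label, proba = predict_abuse(text)
print(f"Label: {label}, Probability: {proba:.2f}")  # e.g. Label: 0, Probability: 0.08
```

The `threshold=0.5` and `max_length=128` values come straight from the committed snippet; they can be tuned to the precision/recall needs of a given deployment.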