uogoit commited on
Commit
4748976
·
verified ·
1 Parent(s): 7a22c82

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import (
4
+ AutoConfig,
5
+ AutoModelForSequenceClassification,
6
+ AutoTokenizer,
7
+ )
8
+
9
+ # ===== 基本配置 =====
10
+ MODEL_DIR = "my-bert-model"
11
+ MAX_LENGTH = 512
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ # ===== 加载模型 =====
16
+ config = AutoConfig.from_pretrained(
17
+ MODEL_DIR,
18
+ num_labels=3,
19
+ finetuning_task="text-classification",
20
+ )
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
23
+
24
+ model = AutoModelForSequenceClassification.from_pretrained(
25
+ MODEL_DIR,
26
+ config=config
27
+ ).to(device)
28
+
29
+ model.eval()
30
+
31
+ # 若未定义 id2label,则自动生成
32
+ if not hasattr(model.config, "id2label") or not model.config.id2label:
33
+ model.config.id2label = {i: f"LABEL_{i}" for i in range(model.config.num_labels)}
34
+
35
+ # ===== 推理函数 =====
36
+ def inference(input_text: str) -> str:
37
+ if not input_text or not input_text.strip():
38
+ return "Empty input."
39
+
40
+ inputs = tokenizer(
41
+ input_text,
42
+ max_length=MAX_LENGTH,
43
+ truncation=True,
44
+ padding="max_length",
45
+ return_tensors="pt",
46
+ )
47
+
48
+ inputs = {k: v.to(device) for k, v in inputs.items()}
49
+
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+ logits = outputs.logits
53
+
54
+ predicted_class_id = torch.argmax(logits, dim=-1).item()
55
+ label = model.config.id2label.get(predicted_class_id, str(predicted_class_id))
56
+
57
+ return label
58
+
59
+ # ===== Gradio 界面 =====
60
+ demo = gr.Interface(
61
+ fn=inference,
62
+ inputs=gr.Textbox(
63
+ label="Input Text",
64
+ placeholder="Enter text to classify...",
65
+ lines=5,
66
+ ),
67
+ outputs=gr.Textbox(label="Predicted Label"),
68
+ examples=[
69
+ ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up."],
70
+ ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!"],
71
+ ],
72
+ title="BERT-based Text Classification",
73
+ description="A text classification demo powered by a fine-tuned BERT model.",
74
+ )
75
+
76
+ # ===== 启动 =====
77
+ if __name__ == "__main__":
78
+ demo.launch(
79
+ debug=False,
80
+ server_name="0.0.0.0",
81
+ server_port=7860,
82
+ )