Dunateo commited on
Commit
2163f78
·
1 Parent(s): b93f72d

no comments need more coffee

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -12,7 +12,7 @@ def predict(text):
12
 
13
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
14
  predicted_class = torch.argmax(probs, dim=-1).item()
15
- return label_dict[str(predicted_class)], probs[0][predicted_class].item()
16
 
17
  if __name__ == '__main__':
18
  model_path = "Dunateo/roberta-cwe-classifier-kelemia-v0.2"
@@ -24,11 +24,14 @@ if __name__ == '__main__':
24
  # get the dict file
25
  label_dict_file = hf_hub_download(repo_id=model_path, filename="label_dict.json")
26
 
27
- global label_dict
28
  with open(label_dict_file, "r") as f:
29
  content = f.read()
30
  label_dict = json.loads(content)
31
 
 
 
 
32
  # gradio specific to create an IHM
33
  iface = gr.Interface(
34
  fn=predict,
 
12
 
13
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
14
  predicted_class = torch.argmax(probs, dim=-1).item()
15
+ return id2label[predicted_class], probs[0][predicted_class].item()
16
 
17
  if __name__ == '__main__':
18
  model_path = "Dunateo/roberta-cwe-classifier-kelemia-v0.2"
 
24
  # get the dict file
25
  label_dict_file = hf_hub_download(repo_id=model_path, filename="label_dict.json")
26
 
27
+
28
  with open(label_dict_file, "r") as f:
29
  content = f.read()
30
  label_dict = json.loads(content)
31
 
32
+ global id2label
33
+ id2label = {v: k for k, v in label_dict.items()}
34
+
35
  # gradio specific to create an IHM
36
  iface = gr.Interface(
37
  fn=predict,