AlzbetaStrompova committed
Commit 75a65be · 1 Parent(s): 19e9ab7
minor changes
Browse files
- app.py +15 -11
- data_manipulation/creation_gazetteers.py +115 -0
- data_manipulation/dataset_funcions.py +124 -212
- data_manipulation/preprocess_gazetteers.py +0 -54
- extended_embeddings/__init__.py +0 -0
- extended_embeddings/{token_classification.py → extended_embedding_token_classification.py} +13 -3
- extended_embeddings/extended_embeddings_data_collator.py +77 -0
- extended_embeddings/extended_embeddings_model.py +12 -39
- flagged/log.csv +0 -8
- requirements.txt +1 -0
- style.css +6 -5
- upload_model.ipynb +3 -3
- website_script.py +32 -4
app.py
CHANGED
@@ -1,32 +1,36 @@
-import json
 import gradio as gr
 from website_script import load, run
 
 tokenizer, model, gazetteers_for_matching = load()
 
 examples = [
-    ["Masarykova univerzita se nachází v
-    ["Barack Obama navštívil Prahu minulý týden
-    ["Angela Merkelová se setkala s francouzským prezidentem v Paříži
-    ["Nobelova cena za fyziku byla udělena týmu vědců z MIT
-]
+    ["Masarykova univerzita se nachází v Brně.", None],
+    ["Barack Obama navštívil Prahu minulý týden.", None],
+    ["Angela Merkelová se setkala s francouzským prezidentem v Paříži.", None],
+    ["Nobelova cena za fyziku byla udělena týmu vědců z MIT.", None],
+    ["Eiffelova věž je ikonickou památkou v Paříži.", None],
+    ["Bill Gates, spoluzakladatel společnosti Microsoft, oznámil nový grant pro výzkum umělé inteligence.", None],
+    ["Britská královna Alžběta II. navštívila Kanadu v rámci svého posledního zahraničního turné, během kterého zdůraznila důležitost spolupráce a přátelství mezi oběma národy.", None],
+    ["Francouzský prezident Emmanuel Macron oznámil nový plán na podporu start-upů a inovací ve Francii, který zahrnuje investice ve výši několika miliard eur.", None],
+    ["Světová zdravotnická organizace spustila nový program na boj proti malárii v subsaharské Africe, který zahrnuje rozdělování sítí proti komárům a očkování milionů lidí.", None]
+]
 
 
 def ner(text, file_names):
+    text = text.replace(".", " .")
     result = run(tokenizer, model, gazetteers_for_matching, text, file_names)
     return {"text": text, "entities": result}
 
 with gr.Blocks(css="./style.css", theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky")) as demo:
     gr.Interface(ner,
-        gr.Textbox(lines=
-
-        gr.HighlightedText(show_legend=True, color_map={"PER": "#f57d7d", "ORG": "#2cf562", "LOC": "#86aafc"}, elem_id="highlighted_text"),
+        gr.Textbox(lines=5, placeholder="Enter sentence here..."),
+        gr.HighlightedText(show_legend=True, color_map={"PER": "#f7a7a3", "ORG": "#77fc6a", "LOC": "#87CEFF"}),
         examples=examples,
         title="NerROB-czech",
-        description="This is an implementation of a Named Entity Recognition model for the Czech language using gazetteers.",
+        description="This is an implementation of a Named Entity Recognition model for the Czech language using gazetteers.",
         allow_flagging="never",
         additional_inputs=gr.File(label="Upload a JSON file containing gazetteers", file_count="multiple", file_types=[".json"]),
         )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
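A minimal way to exercise the updated handler without starting the Gradio UI (an illustrative sketch, not part of the commit; it only uses load() and run() exactly as app.py calls them, and passes None for the optional gazetteer upload):

from website_script import load, run

tokenizer, model, gazetteers_for_matching = load()
# mirror the new preprocessing step in ner(): detach the sentence-final period into its own token
text = "Masarykova univerzita se nachází v Brně.".replace(".", " .")
entities = run(tokenizer, model, gazetteers_for_matching, text, None)
print(entities)  # the span list that gr.HighlightedText renders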
data_manipulation/creation_gazetteers.py
ADDED
@@ -0,0 +1,115 @@
import os
import re
import json
import itertools

import pandas as pd
from simplemma import lemmatize
from names_dataset import NameDataset


def load_json(path):
    """
    Load gazetteers from a file
    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data


def save_json(data, path):
    """
    Save gazetteers to a file
    :param path: path to the gazetteer file
    :param gazetteers: a dict of gazetteers
    """
    with open(path, 'w') as file:
        json.dump(data, file, indent=4)

def merge_gazetteers(*gazetteers):
    """
    Merge multiple gazetteer dictionaries into a single gazetteer dictionary.

    Returns:
        dict: A merged gazetteer dictionary containing all the keys and values from the input gazetteers.
    """
    # Initialize a new dictionary to store merged results
    merged_gazetteers = {}
    # Iterate over each dictionary provided
    for gaz in gazetteers:
        # Iterate over each key and set in the current dictionary
        for key, value_set in gaz.items():
            if key in merged_gazetteers:
                # If the key already exists in the result, union the sets
                merged_gazetteers[key] |= value_set
            else:
                # Otherwise, initialize the key with the set from the current dictionary
                merged_gazetteers[key] = value_set.copy()  # Use copy to avoid mutating the original sets
    return merged_gazetteers


####################################################################################################
### PREPROCESSING OF GAZETTEERS ###################################################################
####################################################################################################

def remove_all_brackets(text):
    return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)


def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


def multi_lemmatizing(x):
    words = x.split(" ")
    phrase = ""
    for word in words:
        phrase += lemmatizing(word) + " "
    return phrase.strip()


def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary


def split_gazetteers_for_single_token_match(gazetteers):
    result = {}
    for k, v in gazetteers.items():
        result[k] = set([x for xs in [vv.split(" ") for vv in v] for x in xs])
        result[k] = {x for x in result[k] if len(x) > 2}
    return result


def preprocess_gazetteers(gazetteers, config):
    if config["remove_brackets"]:
        for k, values in gazetteers.items():
            gazetteers[k] = {remove_all_brackets(vv).strip() for vv in values if len(remove_all_brackets(vv).strip()) > 2}
    if config["split_person"]:
        gazetteers["per"].update(set([x for x in list(itertools.chain(*[v.split(" ") for v in gazetteers["per"]])) if len(x) > 2]))
    if config["techniq_for_matching"] == "single":
        gazetteers = split_gazetteers_for_single_token_match(gazetteers)
        if config["lemmatize"]:
            for k, values in gazetteers.items():
                gazetteers[k] = set(list(itertools.chain(*[(vv, lemmatizing(vv)) for vv in values if len(vv) > 2])))
    elif config["lemmatize"]:
        for k, values in gazetteers.items():
            gazetteers[k] = set(list(itertools.chain(*[(value, multi_lemmatizing(value)) for value in values if len(value) > 2])))

    if config["remove_numeric"]:
        for k, values in gazetteers.items():
            gazetteers[k] = {vv for vv in values if not vv.isnumeric()}
    for k, values in gazetteers.items():
        gazetteers[k] = list(values)
    return gazetteers
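The new module is self-contained, so the helpers above can be chained into a small preprocessing script. The sketch below is illustrative and not part of the commit; the input file name, the gazetteer keys and the config values are assumptions based on how these functions are used elsewhere in this commit.

from data_manipulation.creation_gazetteers import (
    load_json, save_json, preprocess_gazetteers, build_reverse_dictionary,
)

# assumed gazetteer layout: {"PER": [...], "ORG": [...], "LOC": [...]}
gazetteers = {k: set(v) for k, v in load_json("gazetteers_raw.json").items()}

config = {
    "remove_brackets": True,          # strip "(...)", "[...]", "{...}" spans
    "split_person": False,            # would also add single first/last names under the "per" key
    "techniq_for_matching": "single",
    "lemmatize": True,
    "remove_numeric": True,
}
gazetteers = preprocess_gazetteers(gazetteers, config)
save_json(gazetteers, "gazetteers_preprocessed.json")

# one reverse dictionary (surface form -> type) per category, as dataset_funcions.py builds them
reverse_dictionaries = [build_reverse_dictionary({k: v}) for k, v in gazetteers.items()]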
data_manipulation/dataset_funcions.py
CHANGED
@@ -1,27 +1,10 @@
 import os
 import re
-import json
 from tqdm import tqdm
 
 from datasets import Dataset, DatasetDict
 
-
-    """
-    Load gazetteers from a file
-    :param path: path to the gazetteer file
-    :return: a dict of gazetteers
-    """
-    with open(path, 'r') as f:
-        gazetteers = json.load(f)
-    for k, v in gazetteers.items():
-        gazetteers[k] = set(v)
-    return gazetteers
-
-def create_dataset(label_mapper:dict, args):
-    if args.dataset == "cnec":
-        return create_cnec_dataset(label_mapper, args)
-    return load_wikiann_testing_dataset(args)
-
+from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json
 
 ####################################################################################################
 ### GAZETTEERS EMBEDDINGS ##########################################################################
@@ -43,26 +26,36 @@ def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
         i += 1
     return matches
 
-def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
-    return matches
 
-def
+def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
+    n = len(tokens)
+    assert n == len(looking_tokens)
+    for index in range(n):
+        word = looking_tokens[index]
+        if len(word) < 3:
+            continue
+        for gazetteer in gazetteers:
+            if word in gazetteer:
+                match_type = gazetteer[word]
+                matches.setdefault(tokens[index], []).append((word, match_type))
     return matches
 
-def gazetteer_matching(words, gazetteers_for_matching):
-    single_token_match = False
-    ending_ova = False
-    apply_lemmatizing = False
-
+def gazetteer_matching(words, gazetteers_for_matching, args=None):
+    ending_ova = True
+    method_for_gazetteers_matching = "single"
+    apply_lemmatizing = True
 
-
-
+    if method_for_gazetteers_matching == "single":
+        matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
+        if apply_lemmatizing:
+            lemmatize_tokens = [lemmatizing(t) for t in words]
+            matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
     else: # multi_token_match
         matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
-
-
-
+        if apply_lemmatizing:
+            lemmatize_tokens = [lemmatizing(t) for t in words]
+            matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
 
     result = []
     for word in words:
@@ -70,72 +63,18 @@ def gazetteer_matching(words, gazetteers_for_matching):
         per, org, loc = 0, 0, 0
         for res in mid_res:
             if mid_res[0][0].count(" ") == res[0].count(" "):
-                if res[1] == "
-                    per =
-                elif res[1] == "
-                    org =
-                elif res[1] == "
-                    loc =
+                if res[1] == "PER":
+                    per = 5
+                elif res[1] == "ORG":
+                    org = 5
+                elif res[1] == "LOC":
+                    loc = 5
         if ending_ova and word.endswith("ová") and word[0].isupper():
-            per =
+            per = 5
         result.append([per, org, loc])
     return result
 
 
-####################################################################################################
-### GAZETTEERS EXPANSION TRAIN DATASET #############################################################
-####################################################################################################
-
-def expand_train_dataset_with_gazetteers(train, args):
-    if args.apply_extended_embeddings:
-        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
-    gazetteers = load_gazetteers(args.train_gazetteers_path)
-    count_gazetteers = {}
-    id_ = train[-1]["id"]
-    dataset = []
-    for row in train:
-        dataset.append({"id": row['id'], 'tokens': row['tokens'].copy(),
-                        'ner_tags': row['ner_tags'].copy(), 'gazetteers': row['gazetteers'].copy()})
-    for k in gazetteers.keys():
-        count_gazetteers[k] = 0
-    for index in range(args.gazetteers_counter):
-        for row in tqdm(train, desc=f"loop {index} from {args.gazetteers_counter}"):
-            i = 0
-            temp_1 = row["ner_tags"].copy()
-            temp_2 = row["tokens"].copy()
-            if temp_1.count(0) == len(temp_1):
-                continue
-            while i < len(temp_1):
-                tag = temp_1[i]
-                if tag % 2 == 1:
-                    tags = temp_1[:i]
-                    tokens = temp_2[:i]
-                    i += 1
-                    assert len(gazetteers[tag]) > count_gazetteers[tag]
-                    new = gazetteers[tag][count_gazetteers[tag]].split(" ")
-                    count_gazetteers[tag] += 1
-                    while i < len(temp_1):
-                        if temp_1[i] != tag + 1:
-                            break
-                        i += 1
-                    tags.append(tag)
-                    tags.extend([tag + 1] * (len(new) - 1))
-                    tags.extend(temp_1[i:])
-
-                    tokens.extend(new)
-                    tokens.extend(temp_2[i:])
-                    temp_1 = tags
-                    temp_2 = tokens
-                else:
-                    i += 1
-            id_ += 1
-            if args.apply_extended_embeddings:
-                matching = gazetteer_matching(temp_2, gazetteers_for_matching, args)
-                dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1, "gazetteers": matching})
-            dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1})
-    return dataset
-
-
 ####################################################################################################
 ### CNEC DATASET ###################################################################################
 ####################################################################################################
@@ -144,7 +83,6 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
     label_mapper: cnec labels to int
     """
     # Open and read the XML file as plain text
-    assert os.path.isfile(xml_file_path)
     id_ = 0
     with open(xml_file_path, "r", encoding="utf-8") as xml_file:
         plain_text = xml_file.read()
@@ -156,14 +94,13 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
     ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
     data = []
     if args.apply_extended_embeddings:
-        gazetteers_for_matching =
-        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
+        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
         temp = []
         for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
 
-    for sentence in tqdm(sentences):
+    for sentence in tqdm(sentences):
         entity_mapping = []
         while "<ne type=" in sentence: # while because there are nested entities
             nes = re.findall(ne_pattern, sentence)
@@ -215,7 +152,7 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
         if tags_per_word == [] or tags_per_word == [0]:
             continue
         if args.apply_extended_embeddings:
-            matching = gazetteer_matching(words, gazetteers_for_matching)
+            matching = gazetteer_matching(words, gazetteers_for_matching, args)
             data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                          "sentence": " ".join(words), "gazetteers": matching})
         else:
@@ -223,104 +160,78 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
         id_ += 1
     return data
 
-
-        for i in gazetteers_for_matching.keys():
-            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
-        gazetteers_for_matching = temp
-
-    for sentence in tqdm(sentences):
-        entity_mapping = []
-        while "<ne type=" in sentence: # while because there are nested entities
-            nes = re.findall(ne_pattern, sentence)
-            for label, entity in nes:
-                pattern = f'<ne type="{label}">{entity}</ne>'
-                index = sentence.index(pattern)
-                temp_index = index
-                sentence = sentence.replace(pattern, entity, 1)
-                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
-                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
-                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
-                index = temp_index
-                entity_mapping.append((entity, label, index, index + len(entity)))
-
-        entities = []
-        for entity, label, start, end in entity_mapping:
-            for tag in label_mapper.keys():
-                if label.lower().startswith(tag):
-                    entities.append((label_mapper[tag], entity, start, end))
-                    break
-        entities.sort(key=lambda x: len(x[1]), reverse=True)
-
-        words = re.split(r'\s+', sentence)
-        tags_per_word = []
-        sentence_counter = -1
-        for word in words:
-            sentence_counter += len(word) + 1
-            if len(entities) == 0:
-                tags_per_word.append(0) # tag representing no label for no word
-            for index_entity in range(len(entities)):
-                if not(sentence_counter - len(word) >= entities[index_entity][2] and
-                       sentence_counter <= entities[index_entity][3] and
-                       word in entities[index_entity][1]):
-                    if index_entity == len(entities) - 1:
-                        tags_per_word.append(0) # tag representing no label for word
-                    continue
-
-                if True:
-                    if sentence_counter - len(word) == entities[index_entity][2]:
-                        tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
-                    else:
-                        tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
-                else:
-                    tags_per_word.append(entities[index_entity][0])
-                break
-
+def get_default_dataset_from_cnec(label_mapper:dict, xml_file_path):
+    """
+    label_mapper: cnec labels to int
+    """
+    # Open and read the XML file as plain text
+    id_ = 0
+    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
+        plain_text = xml_file.read()
+    plain_text = plain_text[5:-5] # remove unnessery characters
+    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
+    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
+    plain_text = re.sub(r'[ ]+', ' ', plain_text)
+    sentences = plain_text.split("\n")
+    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
+    data = []
+
+    for sentence in tqdm(sentences):
+        entity_mapping = []
+        while "<ne type=" in sentence: # while because there are nested entities
+            nes = re.findall(ne_pattern, sentence)
+            for label, entity in nes:
+                pattern = f'<ne type="{label}">{entity}</ne>'
+                index = sentence.index(pattern)
+                temp_index = index
+                sentence = sentence.replace(pattern, entity, 1)
+                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
+                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
+                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
+                index = temp_index
+                entity_mapping.append((entity, label, index, index + len(entity)))
+
+        entities = []
+        for entity, label, start, end in entity_mapping:
+            for tag in label_mapper.keys():
+                if label.lower().startswith(tag):
+                    entities.append((label_mapper[tag], entity, start, end))
+                    break
+        entities.sort(key=lambda x: len(x[1]), reverse=True)
+
+        words = re.split(r'\s+', sentence)
+        tags_per_word = []
+        sentence_counter = -1
+        for word in words:
+            sentence_counter += len(word) + 1
+            if len(entities) == 0:
+                tags_per_word.append(0) # tag representing no label for no word
+            for index_entity in range(len(entities)):
+                if not(sentence_counter - len(word) >= entities[index_entity][2] and
+                       sentence_counter <= entities[index_entity][3] and
+                       word in entities[index_entity][1]):
+                    if index_entity == len(entities) - 1:
+                        tags_per_word.append(0) # tag representing no label for word
+                    continue
+
+                if sentence_counter - len(word) == entities[index_entity][2]:
+                    tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
+                else:
+                    tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
+
+        if tags_per_word == [] or tags_per_word == [0]:
+            continue
+
+        data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
+        id_ += 1
+    return data
 
 
 def create_cnec_dataset(label_mapper:dict, args):
-
-    assert os.path.isdir(args.cnec_dataset_dir_path)
     dataset = DatasetDict()
     for part, file_name in zip(["train", "validation", "test"],["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
         file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
-        assert os.path.isfile(file_path)
         temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
-        if args.expand_train_data:
-            temp_dataset = expand_train_dataset_with_gazetteers(temp_dataset, args)
         dataset[part] = Dataset.from_list(temp_dataset)
     return dataset
@@ -328,16 +239,19 @@ def create_cnec_dataset(label_mapper:dict, args):
 ### WIKIANN DATASET ################################################################################
 ####################################################################################################
 def load_wikiann_testing_dataset(args):
-    if args.
-        gazetteers_for_matching =
-
+    if args.apply_extended_embeddings:
+        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
+        temp = []
+        for i in gazetteers_for_matching.keys():
+            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
+        gazetteers_for_matching = temp
     dataset = []
     index = 0
     sentences = load_tagged_sentences(args.wikiann_dataset_path)
     for sentence in sentences:
         words = [word for word, _ in sentence]
         tags = [tag for _, tag in sentence]
-        if args.
+        if args.apply_extended_embeddings:
             matching = gazetteer_matching(words, gazetteers_for_matching, args)
             dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
         else:
@@ -345,9 +259,10 @@ def load_wikiann_testing_dataset(args):
             index += 1
 
     test = Dataset.from_list(dataset)
-
-
-
+    dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
+                           "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
+                           "test": test})
+    # dataset = DatasetDict({"test": test})
     return dataset
 
 
@@ -400,26 +315,24 @@ def align_labels_with_tokens(labels, word_ids):
             new_labels.append(label)
     return new_labels
 
+
 def align_gazetteers_with_tokens(gazetteers, word_ids):
-
+    aligned_gazetteers = []
     current_word = None
     for word_id in word_ids:
         if word_id != current_word:
             # Start of a new word!
             current_word = word_id
             gazetteer = [0,0,0] if word_id is None else gazetteers[word_id]
-
+            aligned_gazetteers.append(gazetteer)
         elif word_id is None:
             # Special token
-
+            aligned_gazetteers.append([0,0,0])
         else:
             # Same word as previous token
             gazetteer = gazetteers[word_id]
-
-
-            # gazetteer += 1
-            new_g.append(gazetteer)
-    return new_g
+            aligned_gazetteers.append(gazetteer)
+    return aligned_gazetteers
 
 
 def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
@@ -434,25 +347,24 @@ def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
             new_labels.append(align_labels_with_tokens(labels, word_ids))
         tokenized_inputs["labels"] = new_labels
         if apply_extended_embeddings:
-
-
-            for i,
+            matches = examples["gazetteers"]
+            aligned_matches = []
+            for i, match in enumerate(matches):
                 word_ids = tokenized_inputs.word_ids(i)
-
-
-            for i in
-
-
-
-            tokenized_inputs["per"] =
-            tokenized_inputs["org"] =
-            tokenized_inputs["loc"] =
+                aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
+            per, org, loc = [], [], []
+            for i in aligned_matches:
+                per.append([x[0] for x in i])
+                org.append([x[1] for x in i])
+                loc.append([x[2] for x in i])
+            tokenized_inputs["per"] = per
+            tokenized_inputs["org"] = org
+            tokenized_inputs["loc"] = loc
        return tokenized_inputs
 
-
    dataset = raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
-        remove_columns=raw_dataset["train"].column_names
+        # remove_columns=raw_dataset["train"].column_names
    )
    return dataset
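For orientation, the following illustrative sketch (not part of the commit) shows the shape of the data that the reworked gazetteer_matching feeds into the per/org/loc columns: one [per, org, loc] triple per input word, with 5 marking a gazetteer hit. The tiny reverse dictionaries are made up, and the exact output also depends on the lemmatizer, so the printed result is only indicative.

from data_manipulation.dataset_funcions import gazetteer_matching

gazetteers_for_matching = [          # list of reverse dictionaries, one per category
    {"Barack": "PER", "Obama": "PER"},
    {"Microsoft": "ORG"},
    {"Praha": "LOC", "Prahu": "LOC"},
]
words = ["Barack", "Obama", "navštívil", "Prahu", "."]
print(gazetteer_matching(words, gazetteers_for_matching))
# e.g. [[5, 0, 0], [5, 0, 0], [0, 0, 0], [0, 0, 5], [0, 0, 0]]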
data_manipulation/preprocess_gazetteers.py
DELETED
@@ -1,54 +0,0 @@
import re

from simplemma import lemmatize


def flatten(xss):
    return [x for xs in xss for x in xs]


def remove_all_brackets(text):
    return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)


def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary


def split_gazetteers_for_single_token_match(gazetteers):
    result = {}
    for k, v in gazetteers.items():
        result[k] = set(flatten([vv.split(" ") for vv in v]))
        result[k] = {x for x in result[k] if len(x) > 2}
    return result


def preprocess_gazetteers(gazetteers, config):
    if config["split_person"]:
        gazetteers["PER"].update(set([x for x in flatten([v.split(" ") for v in gazetteers["PER"]]) if len(x) > 2]))
    if config["lemmatize"]:
        for k, v in gazetteers.items():
            gazetteers[k] = set(flatten([(vv, lemmatizing(vv)) for vv in v if len(vv) > 2]))
    if config["remove_brackets"]:
        for k, v in gazetteers.items():
            gazetteers[k] = {remove_all_brackets(vv).strip() for vv in v if len(remove_all_brackets(vv).strip()) > 2}
    if config["remove_numeric"]:
        for k, v in gazetteers.items():
            gazetteers[k] = {vv for vv in v if not vv.isnumeric()}
    if config["techniq_for_matching"] != "single":
        gazetteers = split_gazetteers_for_single_token_match(gazetteers)
    return gazetteers
extended_embeddings/__init__.py
DELETED
File without changes
extended_embeddings/{token_classification.py → extended_embedding_token_classification.py}
RENAMED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -12,11 +12,20 @@ _CONFIG_FOR_DOC = "RobertaConfig"
 
 
 class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassification):
+    """
+    A RobertaForTokenClassification for token classification tasks with extended embeddings.
+
+    This RobertaForTokenClassification extends the functionality of the `RobertaForTokenClassification` class
+    by adding support for additional features such as `per`, `org`, and `loc`.
+
+    Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaForTokenClassification
+
+    """
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.roberta = ExtendedEmbeddigsRobertaModel(config
+        self.roberta = ExtendedEmbeddigsRobertaModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -92,4 +101,5 @@ class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassification):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
+
extended_embeddings/extended_embeddings_data_collator.py
ADDED
@@ -0,0 +1,77 @@
import torch
from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
    """
    A data collator for token classification tasks with extended embeddings.

    This data collator extends the functionality of the `DataCollatorForTokenClassification` class
    by adding support for additional features such as `per`, `org`, and `loc`.

    Part of the code copied from: transformers.data.data_collator.DataCollatorForTokenClassification
    """

    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
        org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]} for feature in features]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch["per"] = [
                to_list(p) + [0] * (sequence_length - len(p)) for p in per
            ]
            batch["org"] = [
                to_list(o) + [0] * (sequence_length - len(o)) for o in org
            ]
            batch["loc"] = [
                to_list(l) + [0] * (sequence_length - len(l)) for l in loc
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch["per"] = [
                [0] * (sequence_length - len(p)) + self.to_list(p) for p in per
            ]
            batch["org"] = [
                [0] * (sequence_length - len(o)) + self.to_list(o) for o in org
            ]
            batch["loc"] = [
                [0] * (sequence_length - len(l)) + self.to_list(l) for l in loc
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
        return batch
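Because the collator pads the extra per/org/loc columns alongside input_ids and labels, it can be dropped into a standard token-classification loop. The sketch below is illustrative and not part of the commit; the tokenizer checkpoint and the toy features are assumptions, chosen only to show the padded batch layout.

from transformers import AutoTokenizer
from extended_embeddings.extended_embeddings_data_collator import (
    ExtendedEmbeddingsDataCollatorForTokenClassification,
)

tokenizer = AutoTokenizer.from_pretrained("ufal/robeczech-base")  # assumed Czech backbone
collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

features = [  # toy pre-tokenized examples with per-token gazetteer scores
    {"input_ids": [0, 1534, 2], "labels": [-100, 1, -100],
     "per": [0, 5, 0], "org": [0, 0, 0], "loc": [0, 0, 0]},
    {"input_ids": [0, 1534, 88, 2], "labels": [-100, 1, 2, -100],
     "per": [0, 5, 5, 0], "org": [0, 0, 0, 0], "loc": [0, 0, 0, 0]},
]
batch = collator(features)
# input_ids, labels, per, org and loc are all padded to the same sequence length
print(batch["per"].shape, batch["labels"].shape)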
extended_embeddings/extended_embeddings_model.py
CHANGED
@@ -1,53 +1,27 @@
-from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings
-from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 from typing import List, Optional, Tuple, Union
+
 import torch
-from
-from
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings
 
-# Copied from transformers.models.bert.modeling_bert.BertPooler
-class ExtendedEmbeddigsRobertaPooler(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        size_of_gazetters_part = int((len(config.id2label.keys()) - 1) // 2)
-        self.dense = nn.Linear(config.hidden_size + size_of_gazetters_part, config.hidden_size + size_of_gazetters_part)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[:, 0]
-        pooled_output = self.dense(first_token_tensor)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output
 
 class ExtendedEmbeddigsRobertaModel(RobertaModel):
     """
+    A RobertaModel for token classification tasks with extended embeddings.
 
-
-
-    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
+    This RobertaModel extends the functionality of the `RobertaModel` class
+    by adding support for additional features such as `per`, `org`, and `loc`.
 
-
-    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
-    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
-
-    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+    Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaModel
 
     """
-
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config):
         super().__init__(config)
         self.config = config
 
         self.embeddings = RobertaEmbeddings(config)
         self.encoder = RobertaEncoder(config)
-
-
-        self.pooler = ExtendedEmbeddigsRobertaPooler(config)
-
+        self.pooler = None
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -57,10 +31,9 @@ class ExtendedEmbeddigsRobertaModel(RobertaModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-
-
-
-        loc: Optional[torch.Tensor] = None, # change
+        per: Optional[torch.Tensor] = None,
+        org: Optional[torch.Tensor] = None,
+        loc: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
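The forward signature change above means the gazetteer features travel through the model as three extra integer tensors shaped like input_ids. An illustrative sketch, not part of the commit; the checkpoint path is a placeholder, and it assumes the token-classification head simply forwards per/org/loc to the backbone, as the diffs in this commit suggest.

import torch
from extended_embeddings.extended_embedding_token_classification import (
    ExtendedEmbeddigsRobertaForTokenClassification,
)

model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained("path/to/checkpoint")  # placeholder
model.eval()

batch_size, seq_len = 2, 16
dummy = {
    "input_ids": torch.randint(5, 1000, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
    "per": torch.zeros(batch_size, seq_len, dtype=torch.long),
    "org": torch.zeros(batch_size, seq_len, dtype=torch.long),
    "loc": torch.zeros(batch_size, seq_len, dtype=torch.long),
}
with torch.no_grad():
    logits = model(**dummy).logits  # (batch_size, seq_len, num_labels)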
flagged/log.csv
DELETED
@@ -1,8 +0,0 @@
text,output,flag,username,timestamp
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:29:01.157209
Barack Obama navštívil Prahu minulý týden .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Barack Obama"", ""class_or_confidence"": ""OSV""}, {""token"": "" nav\u0161t\u00edvil "", ""class_or_confidence"": null}, {""token"": ""Prahu"", ""class_or_confidence"": ""LOC""}, {""token"": "" minul\u00fd t\u00fdden ."", ""class_or_confidence"": null}]",,,2024-05-06 02:31:57.950478
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:51:30.197653
Barack Obama navštívil Prahu minulý týden .,,,,2024-05-06 10:58:33.085992
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:17.762652
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:20.057269
,,,,,2024-05-09 22:59:12.114264
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ torch
 simplemma
 gradio
 pandas
+name-datasets
style.css
CHANGED
@@ -6,10 +6,6 @@ footer {
     color-scheme: light dark;
 }
 
-.container .svelte-ju12zg {
-    color: light-dark(black, white);
-}
-
 .text.svelte-ju12zg {
     padding: 0;
     margin: 0;
@@ -23,4 +19,9 @@ footer {
 .textspan.svelte-ju12zg.no-cat {
     margin: 0;
     padding: 0;
-}
+}
+
+.category-label.svelte-ju12zg {
+    background-color: light-dark(white, black,);
+
+}
upload_model.ipynb
CHANGED
@@ -2,13 +2,13 @@
 "cells": [
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "
+      "model_id": "556291d727474e0a82723d6459722b16",
       "version_major": 2,
       "version_minor": 0
      },
@@ -28,7 +28,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
website_script.py
CHANGED
@@ -2,11 +2,39 @@ import json
 import copy
 
 import torch
+from simplemma import lemmatize
 from transformers import AutoTokenizer
 
-from extended_embeddings.
-from data_manipulation.dataset_funcions import
-
+from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
+from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens
+
+# code originaly from data_manipulation.creation_gazetteers
+def lemmatizing(x):
+    if x == "":
+        return ""
+    return lemmatize(x, lang="cs")
+
+# code originaly from data_manipulation.creation_gazetteers
+def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
+    reverse_dictionary = {}
+    for key, values in dictionary.items():
+        for value in values:
+            reverse_dictionary[value] = key
+            if apply_lemmatizing:
+                temp = lemmatizing(value)
+                if temp != value:
+                    reverse_dictionary[temp] = key
+    return reverse_dictionary
+
+def load_json(path):
+    """
+    Load gazetteers from a file
+    :param path: path to the gazetteer file
+    :return: a dict of gazetteers
+    """
+    with open(path, 'r') as file:
+        data = json.load(file)
+    return data
 
 
 def load():
@@ -18,7 +46,7 @@ def load():
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model.eval()
 
-    gazetteers_for_matching =
+    gazetteers_for_matching = load_json(gazetteers_path)
     temp = []
     for i in gazetteers_for_matching.keys():
         temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))