-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation_utils.py
72 lines (61 loc) · 2.49 KB
/
evaluation_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pandas as pd
from IPython.core.display import display, HTML
def _safe_ratio(numerator, denominator):
    """
    Return numerator / denominator rounded to 2 decimals, or NaN if the denominator is 0
    """
    if denominator == 0:
        return float('nan')
    return round(numerator / denominator, 2)


def print_statistics(results, method):
    """
    Print confusion-matrix counts and derived metrics for one prediction column

    'negated' is treated as the positive class and 'not negated' as the
    negative class. Rows whose label or prediction is neither of these two
    values are not counted in any cell.

    Parameters:
        results: DataFrame with a `label` column and a column named `method`,
            both holding the strings 'negated' / 'not negated'.
        method: name of the column in `results` that holds the predictions.

    Returns:
        None. The counts (tp, tn, fp, fn) and the metrics (recall, precision,
        specificity, accuracy, f1) are printed, one per line. A metric whose
        denominator is zero is printed as `nan` instead of raising
        ZeroDivisionError.
    """
    def count(label_value, predicted_value):
        # Number of rows with this exact (label, prediction) combination.
        matches = (results.label == label_value) & (results[method] == predicted_value)
        return int(matches.sum())

    tp = count('negated', 'negated')
    tn = count('not negated', 'not negated')
    fp = count('not negated', 'negated')
    fn = count('negated', 'not negated')

    # Each ratio is undefined when its denominator is empty (e.g. no positive
    # labels at all); _safe_ratio reports NaN for those cases.
    recall = _safe_ratio(tp, tp + fn)
    precision = _safe_ratio(tp, tp + fp)
    specificity = _safe_ratio(tn, tn + fp)
    accuracy = _safe_ratio(tp + tn, tp + fp + tn + fn)
    f1 = _safe_ratio(2 * tp, (2 * tp) + fp + fn)

    print(f'tp: {tp}')
    print(f'tn: {tn}')
    print(f'fp: {fp}')
    print(f'fn: {fn}')
    print(f'recall: {recall}')
    print(f'precision: {precision}')
    print(f'specificity: {specificity}')
    print(f'accuracy: {accuracy}')
    print(f'f1: {f1}')
def _highlight_entity_html(txt, start, end):
    """
    Build the HTML for a document with the span txt[start:end] highlighted
    """
    from html import escape

    def esc(part):
        # The document is plain text: escape it so characters such as '<' or
        # '&' are shown literally instead of being parsed as markup.
        return escape(part, quote=False)

    marked = f"<mark style='background-color: #fff59d'><text>{esc(txt[start:end])}</text></mark>"
    return ''.join((esc(txt[:start]), marked, esc(txt[end:]), '<br>'))


def get_document_text(entity_id, dcc_dir, predictions=None, print_text=True, print_html=True, obfuscate_entity=False):
    """
    Print and return a document from the DCC dataset based on entity ID

    Parameters:
        entity_id: string of the form '<document name>_<start>_<end>'; the
            fields are taken from the first three '_'-separated parts.
        dcc_dir: root directory of the dataset (str or pathlib.Path). The
            document is read from
            <dcc_dir>/<first two characters of the document name>/<document name>.txt.
        predictions: optional DataFrame with an `entity_id` column; when given,
            the rows matching `entity_id` are printed.
        print_text: when True, show the document followed by an 'Entity: ...' line.
        print_html: when True (and print_text is True), render the document as
            HTML with the entity highlighted; otherwise print it as plain text.
        obfuscate_entity: when True, the entity span in the returned text is
            replaced by '[ENT]'.

    Returns:
        Tuple (text, start, end). NOTE: `start` and `end` are always the
        offsets parsed from `entity_id`; they are not adjusted when
        `obfuscate_entity` changes the length of the text.
    """
    from pathlib import Path

    entity_id_split = entity_id.split('_')
    document_name = entity_id_split[0]
    start = int(entity_id_split[1])
    end = int(entity_id_split[2])
    document_type = document_name[0:2]

    # Read the document; Path() lets callers pass dcc_dir as a plain string too.
    text_path = Path(dcc_dir) / document_type / f'{document_name}.txt'
    with open(text_path, 'r') as text_file:
        text = text_file.read()

    # Print text
    if print_text:
        if print_html:
            display(HTML(_highlight_entity_html(text, start, end)))
        else:
            print(text)
        print(f'Entity: {text[start: end]} ({start}-{end})\n')

    # Print result
    if predictions is not None:
        print(predictions[predictions.entity_id == entity_id])

    if obfuscate_entity:
        # replace the entity with '[ENT]'
        text = text[:start] + '[ENT]' + text[end:]

    # Also return text, start and stop for downstream analysis
    return text, start, end