-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathRock_fact_checker.py
137 lines (122 loc) · 5.08 KB
/
Rock_fact_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import random
import time
import logging
from json import JSONDecodeError
import streamlit as st
from app_utils.backend_utils import load_statements, check_statement, explain_using_llm
from app_utils.frontend_utils import (
set_state_if_absent,
reset_results,
entailment_html_messages,
create_df_for_relevant_snippets,
create_ternary_plot,
build_sidebar,
)
from app_utils.config import RETRIEVER_TOP_K
def _pick_random_statement(statements, current):
    """Return a random statement from *statements*, preferring one != *current*.

    The original rejection loop (``while choice == current: choice = ...``)
    never terminated when every statement equalled *current*, e.g. with a
    single-statement pool.  Drawing from the filtered list keeps the same
    distribution over the other statements and falls back to the whole pool
    when no alternative exists.

    Raises:
        IndexError: if *statements* is empty (from ``random.choice``).
    """
    candidates = [s for s in statements if s != current]
    return random.choice(candidates or statements)


def _request_rerun():
    """Abort the current script run and ask Streamlit to start a new one.

    Unfortunately necessary as the "Random statement" button is rendered
    _below_ the text box, so the new statement can only appear in the box on
    the next run.  Supports both the ``st.scriptrunner`` layout and the
    ``st.runtime.scriptrunner`` layout (adapted for Streamlit>=1.12.0).
    """
    if hasattr(st, "scriptrunner"):
        runner = st.scriptrunner
    else:
        runner = st.runtime.scriptrunner
    raise runner.script_runner.RerunException(
        runner.script_requests.RerunData(widget_states=None)
    )


def _run_fact_check(statement):
    """Check *statement* and store the outcome in ``st.session_state``.

    Sets ``st.session_state.statement`` and ``st.session_state.results`` and
    prints the statement, a UTC timestamp and the elapsed time to stdout.

    Returns:
        True on success; False if an error was logged and shown to the user.
    """
    time_start = time.perf_counter()
    reset_results()
    st.session_state.statement = statement
    with st.spinner("🧠 Performing neural search on documents..."):
        try:
            st.session_state.results = check_statement(statement, RETRIEVER_TOP_K)
        except JSONDecodeError:
            logging.exception("Could not decode the fact-checking results")
            st.error(
                "👓 An error occurred reading the results. Is the document store working?"
            )
            return False
        except Exception:
            # Top-level UI boundary: log the traceback, show a friendly error.
            logging.exception("Fact-checking request failed")
            st.error("🐞 An error occurred during the request.")
            return False
    elapsed = time.perf_counter() - time_start
    print(f"S: {statement}")
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
    print(f"elapsed time: {elapsed}")
    return True


def _display_results():
    """Render the verdict, entailment plot, snippets and optional explanation."""
    results = st.session_state.results
    docs = results["documents"]
    agg_entailment_info = results["aggregate_entailment_info"]

    # The label with the highest aggregate score selects the headline message.
    max_key = max(agg_entailment_info, key=agg_entailment_info.get)
    message = entailment_html_messages[max_key]
    st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

    st.markdown("###### Aggregate entailment information:")
    col1, col2 = st.columns([2, 1])
    fig = create_ternary_plot(agg_entailment_info)
    with col1:
        # theme=None helps to preserve default plotly colors
        st.plotly_chart(fig, use_container_width=True, theme=None)
    with col2:
        st.write(agg_entailment_info)

    st.markdown("###### Most Relevant snippets:")
    df, urls = create_df_for_relevant_snippets(docs)
    st.dataframe(df)
    links = "".join(f"[{doc}]({url}) " for doc, url in urls.items())
    st.markdown("Wikipedia source pages: " + links)

    if max_key != "neutral":
        st.markdown("#### Why ❓ *(experimental)*")
        if st.button("Explain using a Large Language Model 🤖..."):
            explanation = explain_using_llm(
                # The statement the displayed results were computed for.
                statement=st.session_state.statement,
                documents=docs,
                entailment_or_contradiction=max_key,
            )
            st.markdown(explanation)


def main():
    """Build the Rock fact-checking page: input, buttons, check and results."""
    statements = load_statements()
    build_sidebar()

    # Persistent state
    set_state_if_absent("statement", "Elvis Presley is alive")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_statement_requested", False)

    st.write("# Fact Checking 🎸 Rocks!")
    st.write()
    st.markdown(
        """
##### Enter a factual statement about [Rock music](https://en.wikipedia.org/wiki/List_of_mainstream_rock_performers) and let the AI check it out for you...
"""
    )

    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )

    col1, col2 = st.columns(2)
    # Stretch the buttons to the full column width (style injected per column).
    button_css = "<style>.stButton button {width:100%;}</style>"
    for column in (col1, col2):
        column.markdown(button_css, unsafe_allow_html=True)

    run_pressed = col1.button("Run")

    if col2.button("Random statement"):
        reset_results()
        st.session_state.statement = _pick_random_statement(
            statements, st.session_state.statement
        )
        st.session_state.random_statement_requested = True
        _request_rerun()  # raises; nothing below runs in this script pass
    else:
        st.session_state.random_statement_requested = False

    # The random-statement flag is always False at this point (it was just
    # reset above), so it no longer needs to take part in this condition.
    run_query = run_pressed or statement != st.session_state.statement

    if run_query and statement:
        if not _run_fact_check(statement):
            return

    if st.session_state.results:
        _display_results()


main()