diff --git a/python/api-examples-source/tutorials/chat/requirements.txt b/python/api-examples-source/tutorials/chat/requirements.txt new file mode 100644 index 000000000..0db1a7d5c --- /dev/null +++ b/python/api-examples-source/tutorials/chat/requirements.txt @@ -0,0 +1 @@ +lorem \ No newline at end of file diff --git a/python/api-examples-source/tutorials/chat/tutorial-chat-feedback.py b/python/api-examples-source/tutorials/chat/tutorial-chat-feedback.py new file mode 100644 index 000000000..99205490e --- /dev/null +++ b/python/api-examples-source/tutorials/chat/tutorial-chat-feedback.py @@ -0,0 +1,45 @@ +import streamlit as st +import time + + +def chat_stream(prompt): + response = f'You said, "{prompt}" ...interesting.' + for char in response: + yield char + time.sleep(0.02) + + +def save_feedback(index): + st.session_state.history[index]["feedback"] = st.session_state[f"feedback_{index}"] + + +if "history" not in st.session_state: + st.session_state.history = [] + +for i, message in enumerate(st.session_state.history): + with st.chat_message(message["role"]): + st.write(message["content"]) + if message["role"] == "assistant": + feedback = message.get("feedback", None) + st.session_state[f"feedback_{i}"] = feedback + st.feedback( + "thumbs", + key=f"feedback_{i}", + disabled=feedback is not None, + on_change=save_feedback, + args=[i], + ) + +if prompt := st.chat_input("Say something"): + with st.chat_message("user"): + st.write(prompt) + st.session_state.history.append({"role": "user", "content": prompt}) + with st.chat_message("assistant"): + response = st.write_stream(chat_stream(prompt)) + st.feedback( + "thumbs", + key=f"feedback_{len(st.session_state.history)}", + on_change=save_feedback, + args=[len(st.session_state.history)], + ) + st.session_state.history.append({"role": "assistant", "content": response}) diff --git a/python/api-examples-source/tutorials/chat/tutorial-chat-revision.py b/python/api-examples-source/tutorials/chat/tutorial-chat-revision.py new file mode 100644 index 000000000..92c3daa6e --- /dev/null +++ b/python/api-examples-source/tutorials/chat/tutorial-chat-revision.py @@ -0,0 +1,147 @@ +import streamlit as st +import lorem +from random import randint +import time + +if "stage" not in st.session_state: + st.session_state.stage = "user" + st.session_state.history = [] + st.session_state.pending = None + st.session_state.validation = {} + + +def chat_stream(): + for i in range(randint(3, 9)): + yield lorem.sentence() + " " + time.sleep(0.2) + + +def validate(response): + response_sentences = response.split(". ") + response_sentences = [ + sentence.strip(". ") + "." + for sentence in response_sentences + if sentence.strip(". ") != "" + ] + validation_list = [ + True if sentence.count(" ") > 4 else False for sentence in response_sentences + ] + return response_sentences, validation_list + + +def add_highlights(response_sentences, validation_list, bg="red", text="red"): + return [ + f":{text}[:{bg}-background[" + sentence + "]]" if not is_valid else sentence + for sentence, is_valid in zip(response_sentences, validation_list) + ] + + +for message in st.session_state.history: + with st.chat_message(message["role"]): + st.write(message["content"]) + +if st.session_state.stage == "user": + if user_input := st.chat_input("Enter a prompt"): + st.session_state.history.append({"role": "user", "content": user_input}) + with st.chat_message("user"): + st.write(user_input) + with st.chat_message("assistant"): + response = st.write_stream(chat_stream()) + st.session_state.pending = response + st.session_state.stage = "validate" + st.rerun() + +elif st.session_state.stage == "validate": + st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) + response_sentences, validation_list = validate(st.session_state.pending) + highlighted_sentences = add_highlights(response_sentences, validation_list) + with st.chat_message("assistant"): + st.markdown(" ".join(highlighted_sentences)) + st.divider() + cols = st.columns(3) + if cols[0].button( + "Correct errors", type="primary", disabled=all(validation_list) + ): + st.session_state.validation = { + "sentences": response_sentences, + "valid": validation_list, + } + st.session_state.stage = "correct" + st.rerun() + if cols[1].button("Accept"): + st.session_state.history.append( + {"role": "assistant", "content": st.session_state.pending} + ) + st.session_state.pending = None + st.session_state.validation = {} + st.session_state.stage = "user" + st.rerun() + if cols[2].button("Rewrite answer", type="tertiary"): + st.session_state.stage = "rewrite" + st.rerun() + +elif st.session_state.stage == "correct": + st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) + response_sentences = st.session_state.validation["sentences"] + validation_list = st.session_state.validation["valid"] + highlighted_sentences = add_highlights( + response_sentences, validation_list, "gray", "gray" + ) + if not all(validation_list): + focus = validation_list.index(False) + highlighted_sentences[focus] = ":red[:red" + highlighted_sentences[focus][11:] + else: + focus = None + with st.chat_message("assistant"): + st.markdown(" ".join(highlighted_sentences)) + st.divider() + if focus is not None: + new_sentence = st.text_input( + "Replacement text:", value=response_sentences[focus] + ) + cols = st.columns(2) + if cols[0].button( + "Update", type="primary", disabled=len(new_sentence.strip()) < 1 + ): + st.session_state.validation["sentences"][focus] = ( + new_sentence.strip(". ") + "." + ) + st.session_state.validation["valid"][focus] = True + st.session_state.pending = " ".join( + st.session_state.validation["sentences"] + ) + st.rerun() + if cols[1].button("Remove"): + st.session_state.validation["sentences"].pop(focus) + st.session_state.validation["valid"].pop(focus) + st.session_state.pending = " ".join( + st.session_state.validation["sentences"] + ) + st.rerun() + else: + cols = st.columns(2) + if cols[0].button("Accept", type="primary"): + st.session_state.history.append( + {"role": "assistant", "content": st.session_state.pending} + ) + st.session_state.pending = None + st.session_state.validation = {} + st.session_state.stage = "user" + st.rerun() + if cols[1].button("Re-validate"): + st.session_state.validation = {} + st.session_state.stage = "validate" + st.rerun() + +elif st.session_state.stage == "rewrite": + st.chat_input("Accept, correct, or rewrite the answer above.", disabled=True) + with st.chat_message("assistant"): + new = st.text_area("Rewrite the answer", value=st.session_state.pending) + if st.button( + "Update", type="primary", disabled=new is None or new.strip(". ") == "" + ): + st.session_state.history.append({"role": "assistant", "content": new}) + st.session_state.pending = None + st.session_state.validation = {} + st.session_state.stage = "user" + st.rerun()