# Patch summary (leading commentary; everything from the "diff --git" header onward is unchanged).
#
# Purpose: adds an optional translation step to main.py so generated note sections can be
# rewritten in a user-selected target language.
#
# Hunk @@ -94,6 +94,15 @@  -- two new methods placed between flatten_structure and update_content:
#   * reset_content(title): assigns "" to self.contents[title].
#   * get_content(title): returns self.contents[title]; on KeyError the handler is `pass`,
#     so the method falls off the end and returns None implicitly.
#
# Hunk @@ -253,6 +262,37 @@ -- new generator translate_content(text, language, model="llama3-8b-8192"):
#   * calls st.session_state.groq.chat.completions.create(...) with stream=True,
#     temperature=0.3, max_tokens=8000, top_p=1, stop=None, a system prompt asking for a
#     translation of the English TEXT into {language}, and a user message embedding {text};
#   * yields each non-empty chunk.choices[0].delta.content;
#   * when chunk.x_groq is truthy and has usage, yields a GenerationStatistics built from
#     that usage and the model name.
#
# Hunk @@ -322,11 +362,13 @@ -- the "Customization Settings" description now mentions
#   translation, and a "Target language:" selectbox offering English, Spanish and French
#   is added; its value is stored in target_language_selected.
#
# Hunk @@ -452,7 +494,7 @@  -- the empty line after clear_status() is replaced by a
#   print("Target language: ", target_language_selected) call.
#
# Hunk @@ -476,6 +518,21 @@ -- in stream_section_content: when target_language_selected is
#   not "English", translate_content is called with get_content(title), the chosen language
#   and content_selected_model; reset_content(title) then empties the section; each streamed
#   item is either added to total_generation_statistics (followed by a statistics_text
#   refresh and display_statistics()) or appended to the section through update_content.
#   The text argument is evaluated at the call, i.e. before reset_content runs.
#
# NOTE(review): the line breaks and indentation of this patch have been collapsed, so the
#   nesting of the last hunk cannot be read from here -- presumably the translation block
#   runs after the generation loop for a section has finished rather than once per
#   generated chunk; confirm against the real file before applying.
# NOTE(review): get_content returns None for an unknown title, and translate_content
#   interpolates text into an f-string, so a missing title would send the literal "None"
#   to the model -- confirm whether "" or an explicit error is the intended behaviour.
# NOTE(review): `except KeyError as e` binds a name that is never used.
# NOTE(review): the new loop tests `type(chunk_data) == GenerationStatistics`;
#   isinstance(...) is the idiomatic check. The surrounding context ends with an identical
#   `elif chunk is not None: ... update_content(title, chunk)` branch, so the pattern looks
#   copied from the existing generation loop -- change both together if at all.
# NOTE(review): the inner translation loop reuses the loop variable name `chunk` that the
#   context lines above it also use.
# NOTE(review): the added print(...) looks like leftover debug output -- confirm it is wanted.
# NOTE(review): the system prompt states the input is English and the model defaults to /
#   is passed the content-generation model; no separate translation-model setting is added.
diff --git a/main.py b/main.py index d5ac0b2b..45b5c7ac 100644 --- a/main.py +++ b/main.py @@ -94,6 +94,15 @@ def flatten_structure(self, structure): sections.extend(self.flatten_structure(content)) return sections + def reset_content(self, title): + self.contents[title] = "" + + def get_content(self, title): + try: + return self.contents[title] + except KeyError as e: + pass + def update_content(self, title, new_content): try: self.contents[title] += new_content @@ -253,6 +262,37 @@ def generate_section(transcript: str, existing_notes: str, section: str, model: statistics_to_return = GenerationStatistics(input_time=usage.prompt_time, output_time=usage.completion_time, input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens, total_time=usage.total_time, model_name=model) yield statistics_to_return +def translate_content(text: str, language: str, model: str = "llama3-8b-8192"): + stream = st.session_state.groq.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": f"You are an expert Translator. Your sole objective is to Translate the input English TEXT into {language}. DO NOT PROVIDE ANYTHING ELSE EXCEPT THE TRANSLATION." 
+ }, + { + "role": "user", + "content": f"### TEXT\n\n{text}\n\n### Translation to {language}:" + } + ], + temperature=0.3, + max_tokens=8000, + top_p=1, + stream=True, + stop=None, + ) + + for chunk in stream: + tokens = chunk.choices[0].delta.content + if tokens: + yield tokens + if x_groq := chunk.x_groq: + if not x_groq.usage: + continue + usage = x_groq.usage + statistics_to_return = GenerationStatistics(input_time=usage.prompt_time, output_time=usage.completion_time, input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens, total_time=usage.total_time, model_name=model) + yield statistics_to_return + # Initialize if 'button_disabled' not in st.session_state: st.session_state.button_disabled = False @@ -322,11 +362,13 @@ def empty_st(): st.write(f"---") st.write("# Customization Settings\n🧪 These settings are experimental.\n") - st.write(f"By default, GroqNotes uses Llama3-70b for generating the notes outline and Llama3-8b for the content. This balances quality with speed and rate limit usage. You can customize these selections below.") + st.write(f"By default, GroqNotes uses Llama3-70b for generating the notes outline and Llama3-8b for the content and translation. This balances quality with speed and rate limit usage. 
You can customize these selections below.") outline_model_options = ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"] outline_selected_model = st.selectbox("Outline generation:", outline_model_options) content_model_options = ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it", "gemma2-9b-it"] content_selected_model = st.selectbox("Content generation:", content_model_options) + target_language_options = ["English", "Spanish", "French"] + target_language_selected = st.selectbox("Target language:", target_language_options) # Add note about rate limits @@ -452,7 +494,7 @@ def display_statistics(): total_generation_statistics = GenerationStatistics(model_name=str(content_selected_model)) clear_status() - + print("Target language: ", target_language_selected) try: notes_structure_json = json.loads(notes_structure) notes = NoteSection(structure=notes_structure_json,transcript=transcription_text) @@ -476,6 +518,21 @@ def stream_section_content(sections): display_statistics() elif chunk is not None: st.session_state.notes.update_content(title, chunk) + if target_language_selected != "English": + translated_content_stream = translate_content(text=st.session_state.notes.get_content(title), language=target_language_selected, model=content_selected_model) + # Reset content to empty string for translation + st.session_state.notes.reset_content(title) + for chunk in translated_content_stream: + # Check if GenerationStatistics data is returned instead of str tokens + chunk_data = chunk + if type(chunk_data) == GenerationStatistics: + total_generation_statistics.add(chunk_data) + + st.session_state.statistics_text = str(total_generation_statistics) + display_statistics() + elif chunk is not None: + st.session_state.notes.update_content(title, chunk) + elif isinstance(content, dict): stream_section_content(content)