Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add translation support for target language selection in sidebar #25

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def flatten_structure(self, structure):
sections.extend(self.flatten_structure(content))
return sections

def reset_content(self, title):
self.contents[title] = ""

def get_content(self, title):
try:
return self.contents[title]
except KeyError as e:
pass

def update_content(self, title, new_content):
try:
self.contents[title] += new_content
Expand Down Expand Up @@ -253,6 +262,37 @@ def generate_section(transcript: str, existing_notes: str, section: str, model:
statistics_to_return = GenerationStatistics(input_time=usage.prompt_time, output_time=usage.completion_time, input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens, total_time=usage.total_time, model_name=model)
yield statistics_to_return

def translate_content(text: str, language: str, model: str = "llama3-8b-8192"):
stream = st.session_state.groq.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": f"You are an expert Translator. Your sole objective is to Translate the input English TEXT into {language}. DO NOT PROVIDE ANYTHING ELSE EXCEPT THE TRANSLATION."
},
{
"role": "user",
"content": f"### TEXT\n\n{text}\n\n### Translation to {language}:"
}
],
temperature=0.3,
max_tokens=8000,
top_p=1,
stream=True,
stop=None,
)

for chunk in stream:
tokens = chunk.choices[0].delta.content
if tokens:
yield tokens
if x_groq := chunk.x_groq:
if not x_groq.usage:
continue
usage = x_groq.usage
statistics_to_return = GenerationStatistics(input_time=usage.prompt_time, output_time=usage.completion_time, input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens, total_time=usage.total_time, model_name=model)
yield statistics_to_return

# Initialize
if 'button_disabled' not in st.session_state:
st.session_state.button_disabled = False
Expand Down Expand Up @@ -322,11 +362,13 @@ def empty_st():
st.write(f"---")

st.write("# Customization Settings\n🧪 These settings are experimental.\n")
st.write(f"By default, GroqNotes uses Llama3-70b for generating the notes outline and Llama3-8b for the content. This balances quality with speed and rate limit usage. You can customize these selections below.")
st.write(f"By default, GroqNotes uses Llama3-70b for generating the notes outline and Llama3-8b for the content and translation. This balances quality with speed and rate limit usage. You can customize these selections below.")
outline_model_options = ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
outline_selected_model = st.selectbox("Outline generation:", outline_model_options)
content_model_options = ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it", "gemma2-9b-it"]
content_selected_model = st.selectbox("Content generation:", content_model_options)
target_language_options = ["English", "Spanish", "French"]
target_language_selected = st.selectbox("Target language:", target_language_options)


# Add note about rate limits
Expand Down Expand Up @@ -452,7 +494,7 @@ def display_statistics():
total_generation_statistics = GenerationStatistics(model_name=str(content_selected_model))
clear_status()


print("Target language: ", target_language_selected)
try:
notes_structure_json = json.loads(notes_structure)
notes = NoteSection(structure=notes_structure_json,transcript=transcription_text)
Expand All @@ -476,6 +518,21 @@ def stream_section_content(sections):
display_statistics()
elif chunk is not None:
st.session_state.notes.update_content(title, chunk)
if target_language_selected != "English":
translated_content_stream = translate_content(text=st.session_state.notes.get_content(title), language=target_language_selected, model=content_selected_model)
# Reset content to empty string for translation
st.session_state.notes.reset_content(title)
for chunk in translated_content_stream:
# Check if GenerationStatistics data is returned instead of str tokens
chunk_data = chunk
if type(chunk_data) == GenerationStatistics:
total_generation_statistics.add(chunk_data)

st.session_state.statistics_text = str(total_generation_statistics)
display_statistics()
elif chunk is not None:
st.session_state.notes.update_content(title, chunk)

elif isinstance(content, dict):
stream_section_content(content)

Expand Down