diff --git a/Cargo.toml b/Cargo.toml index edece14..c89c54f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ license-file = "LICENSE" homepage = "https://github.com/groovybits/rsllm/wiki" repository = "https://github.com/groovybits/rsllm" authors = ["Chris Kennedy"] -version = "0.4.0" +version = "0.4.1" edition = "2021" [lib] diff --git a/scripts/alice.sh b/scripts/alice.sh index b2d721f..06eff6e 100755 --- a/scripts/alice.sh +++ b/scripts/alice.sh @@ -8,9 +8,16 @@ # # -BUILD_TYPE=debug +BUILD_TYPE=release MODEL=mistral MODEL_ID=7b-it +MAX_TOKENS=1000 +ALIGNMENT=right +TEMPERATURE=0.8 +POLL_INTERVAL=10 +IMAGE_CONCURRENCY=1 +SPEECH_CONCURRENCY=1 +CONTEXT_SIZE=3000 DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \ RUST_BACKTRACE=full target/${BUILD_TYPE}/rsllm \ --query "create a story based on an anime About Alice an adult twitch streaming girl who lives in AI Wonderland. Have it vary off the title 'Alice in AI Wonderland' with a random plotline you create based on classic anime characters appearing in the wonderland. Alices AI Wonderland is a happy fun show where Alice goes through experiences similar to Alice in Wonderland where she grows small or large depending one what she eats. Add in AI technology twists. Have it fully formatted like a transcript with the character speaking parts mostly speaking in first person, minimal narration. create a whole episode full length with classic anime characters with Alice the main character of AI Wonderland." \ @@ -22,10 +29,11 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \ --ndi-images \ --mimic3-tts \ --model-id $MODEL_ID \ - --image-alignment right \ - --temperature 0.8 \ - --image-concurrency 1 \ - --speech-concurrency 1\ - --max-concurrent-sd-image-tasks 8 \ + --image-alignment $ALIGNMENT \ + --temperature $TEMPERATURE \ + --image-concurrency $IMAGE_CONCURRENCY \ + --speech-concurrency $SPEECH_CONCURRENCY \ + --poll-interval $POLL_INTERVAL \ + --llm-history-size $CONTEXT_SIZE \ --daemon \ - --max-tokens 1200 $@ + --max-tokens $MAX_TOKENS $@ diff --git a/src/args.rs b/src/args.rs index b1ce036..eed4f1f 100644 --- a/src/args.rs +++ b/src/args.rs @@ -4,7 +4,7 @@ use clap::Parser; #[derive(Parser, Debug, Clone)] #[clap( author = "Chris Kennedy", - version = "0.4.0", + version = "0.4.1", about = "Rust LLM Stream Analyzer and Content Generator" )] pub struct Args { @@ -17,7 +17,7 @@ pub struct Args { )] pub system_prompt: String, - /// System prompt + /// Prompt #[clap( long, env = "QUERY", @@ -596,6 +596,15 @@ pub struct Args { )] pub image_alignment: String, + /// Subtitles - enable subtitles + #[clap( + long, + env = "SUBTITLES", + default_value_t = false, + help = "Subtitles - enable subtitles." + )] + pub subtitles: bool, + /// Subtitle position - top, mid-top, center, mid-bottom, bottom - bottom is default #[clap( long, diff --git a/src/main.rs b/src/main.rs index 8e8dc2e..3366bc3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -95,61 +95,72 @@ async fn main() { let image_sem = Arc::new(Semaphore::new(args.image_concurrency)); let speech_sem = Arc::new(Semaphore::new(args.speech_concurrency)); + let ndi_sem = Arc::new(Semaphore::new(1)); // Image processing task + let running_processor_images = Arc::new(AtomicBool::new(true)); let image_processing_task = { let image_sem = Arc::clone(&image_sem); let processed_data_store = processed_data_store.clone(); + let running_processor_clone = running_processor_images.clone(); tokio::spawn(async move { - while let Some(message_data) = image_task_receiver.recv().await { - let images = process_image(message_data.clone(), Arc::clone(&image_sem)).await; - let mut store = processed_data_store.lock().await; - - match store.entry(message_data.paragraph_count) { - std::collections::hash_map::Entry::Vacant(e) => { - e.insert(ProcessedData { - paragraph: message_data.paragraph.clone(), - image_data: Some(images), - audio_data: None, - paragraph_count: message_data.paragraph_count, - subtitle_position: message_data.subtitle_position.clone(), - time_stamp: 0, - }); - } - std::collections::hash_map::Entry::Occupied(mut e) => { - let entry = e.get_mut(); - entry.image_data = Some(images); + while running_processor_clone.load(Ordering::SeqCst) { + while let Some(message_data) = image_task_receiver.recv().await { + let images = process_image(message_data.clone(), Arc::clone(&image_sem)).await; + let mut store = processed_data_store.lock().await; + + match store.entry(message_data.paragraph_count) { + std::collections::hash_map::Entry::Vacant(e) => { + e.insert(ProcessedData { + paragraph: message_data.paragraph.clone(), + image_data: Some(images), + audio_data: None, + paragraph_count: message_data.paragraph_count, + subtitle_position: message_data.subtitle_position.clone(), + time_stamp: 0, + }); + } + std::collections::hash_map::Entry::Occupied(mut e) => { + let entry = e.get_mut(); + entry.image_data = Some(images); + } } + break; } } }) }; // Speech processing task + let running_processor_speech = Arc::new(AtomicBool::new(true)); let speech_processing_task = { let speech_sem = Arc::clone(&speech_sem); let processed_data_store = processed_data_store.clone(); + let running_processor_clone = running_processor_speech.clone(); tokio::spawn(async move { - while let Some(message_data) = speech_task_receiver.recv().await { - let speech_data = - process_speech(message_data.clone(), Arc::clone(&speech_sem)).await; - let mut store = processed_data_store.lock().await; - - match store.entry(message_data.paragraph_count) { - std::collections::hash_map::Entry::Vacant(e) => { - e.insert(ProcessedData { - paragraph: message_data.paragraph.clone(), - image_data: None, - audio_data: Some(speech_data), - paragraph_count: message_data.paragraph_count, - subtitle_position: message_data.subtitle_position.clone(), - time_stamp: 0, - }); - } - std::collections::hash_map::Entry::Occupied(mut e) => { - let entry = e.get_mut(); - entry.audio_data = Some(speech_data); + while running_processor_clone.load(Ordering::SeqCst) { + while let Some(message_data) = speech_task_receiver.recv().await { + let speech_data = + process_speech(message_data.clone(), Arc::clone(&speech_sem)).await; + let mut store = processed_data_store.lock().await; + + match store.entry(message_data.paragraph_count) { + std::collections::hash_map::Entry::Vacant(e) => { + e.insert(ProcessedData { + paragraph: message_data.paragraph.clone(), + image_data: None, + audio_data: Some(speech_data), + paragraph_count: message_data.paragraph_count, + subtitle_position: message_data.subtitle_position.clone(), + time_stamp: 0, + }); + } + std::collections::hash_map::Entry::Occupied(mut e) => { + let entry = e.get_mut(); + entry.audio_data = Some(speech_data); + } } + break; } } }) @@ -159,8 +170,11 @@ async fn main() { let processed_data_store_for_ndi = processed_data_store.clone(); let args_for_ndi = args.clone(); + let running_processor_ndi = Arc::new(AtomicBool::new(true)); + let running_processor_clone = running_processor_ndi.clone(); let ndi_sync_task = tokio::spawn(async move { - loop { + let ndi_sem = Arc::clone(&ndi_sem); + while running_processor_clone.load(Ordering::SeqCst) { // Artificial delay for demonstration; adjust as needed tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -180,7 +194,7 @@ async fn main() { for key in keys_to_remove { if let Some(data) = processed_data_store_for_ndi.lock().await.remove(&key) { - send_to_ndi(data, &args_for_ndi).await; + send_to_ndi(data, &args_for_ndi, Arc::clone(&ndi_sem)).await; } } } @@ -422,6 +436,7 @@ async fn main() { // empty decode_batch decode_batch.clear(); } + break; } } else { // sleep for a while to avoid busy loop @@ -434,28 +449,34 @@ async fn main() { .ok() .unwrap_or_else(|| "NO_AUTH_KEY".to_string()); + let running_processor_twitch = Arc::new(AtomicBool::new(true)); if args.twitch_client { - if twitch_auth == "NO_AUTH_KEY" { - error!("Twitch Auth key is not set. Please set the TWITCH_AUTH environment variable."); - std::process::exit(1); - } - // Clone values before moving them into the closure let twitch_channel_clone = vec![args.twitch_channel.clone()]; let twitch_username_clone = args.twitch_username.clone(); let twitch_auth_clone = twitch_auth.clone(); // Assuming twitch_auth is clonable and you want to use it within the closure. - info!( - "Setting up Twitch channel {} for user {}", - twitch_channel_clone.join(", "), // Assuming it's a Vec - twitch_username_clone - ); - + // TODO: add mpsc channels for communication between the twitch setup and the main thread + let running_processor_clone = running_processor_twitch.clone(); let _twitch_handle = tokio::spawn(async move { + info!( + "Setting up Twitch channel {} for user {}", + twitch_channel_clone.join(", "), // Assuming it's a Vec + twitch_username_clone + ); + + if twitch_auth == "NO_AUTH_KEY" { + error!( + "Twitch Auth key is not set. Please set the TWITCH_AUTH environment variable." + ); + std::process::exit(1); + } + match twitch_setup( twitch_username_clone.clone(), twitch_auth_clone, twitch_channel_clone.clone(), + running_processor_clone, ) .await { @@ -477,10 +498,11 @@ async fn main() { } }); - // Wait for the twitch setup to complete - //if let Err(e) = twitch_handle.await { - // error!("Error setting up Twitch channel: {}", e); - //} + //TODO: put this at the end. + // wait for the running_processor to be set to false + /*if let Err(e) = twitch_handle.await { + error!("Error waiting for Twitch channel: {}", e); + }*/ } let poll_interval = args.poll_interval; let poll_interval_duration = Duration::from_millis(poll_interval); @@ -735,11 +757,6 @@ async fn main() { let mut current_paragraph: Vec = Vec::new(); let mut paragraph_count = 0; let min_paragraph_len = args.sd_text_min; // Minimum length of a paragraph to generate an image - let mut image_spawn_handles = Vec::new(); - - // Stable Diffusion number of tasks max - // Before starting loop, initialize the semaphore with a specific number of permits - let semaphore_sd_image = Arc::new(Semaphore::new(args.max_concurrent_sd_image_tasks)); // create uuid unique identifier for the output images let output_id = Uuid::new_v4().simple().to_string(); // Generates a UUID and converts it to a simple, hyphen-free string @@ -829,7 +846,6 @@ async fn main() { // Clone necessary data for use in the async block let paragraph_clone = paragraphs[paragraph_count].clone(); let output_id_clone = output_id.clone(); - let sem_clone_sd_image = semaphore_sd_image.clone(); let mimic3_voice = args.mimic3_voice.clone().to_string(); let image_alignment = args.image_alignment.clone(); let subtitle_position = args.subtitle_position.clone(); @@ -838,75 +854,58 @@ async fn main() { let image_task_sender_clone = image_task_sender.clone(); let speech_task_sender_clone = speech_task_sender.clone(); - let handle = tokio::spawn(async move { - // Declare the permit variable outside the if block to extend its scope - let _permit = if args.sd_image - || (args.mimic3_tts || args.oai_tts || args.tts_enable) - { - // Conditionally acquire the permit and store it in an Option - Some(sem_clone_sd_image.acquire().await.expect( - "Stable Diffusion: Failed to acquire semaphore permit", - )) - } else { - // If the condition is not met, no permit is acquired, and None is stored - None - }; + let mut sd_config = SDConfig::new(); + sd_config.prompt = paragraph_clone; + sd_config.height = Some(args.sd_height); + sd_config.width = Some(args.sd_width); + sd_config.image_position = Some(image_alignment); + if args.sd_scaled_height > 0 { + sd_config.scaled_height = Some(args.sd_scaled_height); + } + if args.sd_scaled_width > 0 { + sd_config.scaled_width = Some(args.sd_scaled_width); + } - let mut sd_config = SDConfig::new(); - sd_config.prompt = paragraph_clone; - sd_config.height = Some(args.sd_height); - sd_config.width = Some(args.sd_width); - sd_config.image_position = Some(image_alignment); - if args.sd_scaled_height > 0 { - sd_config.scaled_height = Some(args.sd_scaled_height); - } - if args.sd_scaled_width > 0 { - sd_config.scaled_width = Some(args.sd_scaled_width); - } + let args_clone = args.clone(); + let mimic3_voice_clone = mimic3_voice.clone(); + let subtitle_position_clone = subtitle_position.clone(); - let args_clone = args.clone(); - let mimic3_voice_clone = mimic3_voice.clone(); - let subtitle_position_clone = subtitle_position.clone(); - - if args.sd_image { - debug!("Generating images with prompt: {}", sd_config.prompt); - - // Create MessageData for image task - let message_data_for_image = MessageData { - paragraph: sd_config.prompt.clone(), // Clone for the image task - output_id: output_id_clone.clone(), - paragraph_count, - sd_config: sd_config.clone(), // Assuming SDConfig is set up previously and is cloneable - mimic3_voice: mimic3_voice_clone.clone(), - subtitle_position: subtitle_position_clone.clone(), - args: args_clone.clone(), - }; - - // For image tasks - image_task_sender_clone - .send(message_data_for_image) - .await - .expect("Failed to send image task"); - } + if args.sd_image { + debug!("Generating images with prompt: {}", sd_config.prompt); - let message_data_for_speech = MessageData { - paragraph: sd_config.prompt.clone(), // Already cloned above - output_id: output_id_clone.clone(), // Clone again for speech task + // Create MessageData for image task + let message_data_for_image = MessageData { + paragraph: sd_config.prompt.clone(), // Clone for the image task + output_id: output_id_clone.clone(), paragraph_count, - sd_config: sd_config.clone(), // Clone again if necessary - mimic3_voice: mimic3_voice_clone, // Already cloned above - subtitle_position: subtitle_position_clone, // Already cloned above - args: args_clone, // Already cloned above + sd_config: sd_config.clone(), // Assuming SDConfig is set up previously and is cloneable + mimic3_voice: mimic3_voice_clone.clone(), + subtitle_position: subtitle_position_clone.clone(), + args: args_clone.clone(), }; - // For speech tasks - speech_task_sender_clone - .send(message_data_for_speech) + // For image tasks + image_task_sender_clone + .send(message_data_for_image) .await - .expect("Failed to send speech task"); - }); + .expect("Failed to send image task"); + } + + let message_data_for_speech = MessageData { + paragraph: sd_config.prompt.clone(), // Already cloned above + output_id: output_id_clone.clone(), // Clone again for speech task + paragraph_count, + sd_config: sd_config.clone(), // Clone again if necessary + mimic3_voice: mimic3_voice_clone, // Already cloned above + subtitle_position: subtitle_position_clone, // Already cloned above + args: args_clone, // Already cloned above + }; - image_spawn_handles.push(handle); + // For speech tasks + speech_task_sender_clone + .send(message_data_for_speech) + .await + .expect("Failed to send speech task"); } // ** End of TTS and Image Generation ** @@ -944,7 +943,6 @@ async fn main() { let paragraph_text = current_paragraph.join(""); // Join without spaces as indicated let paragraph_clone = paragraph_text.clone(); let output_id_clone = output_id.clone(); - let sem_clone_sd_image = semaphore_sd_image.clone(); let mimic3_voice = args.mimic3_voice.clone().to_string(); let image_alignment = args.image_alignment.clone(); let subtitle_position = args.subtitle_position.clone(); @@ -953,77 +951,58 @@ async fn main() { let image_task_sender_clone = image_task_sender.clone(); let speech_task_sender_clone = speech_task_sender.clone(); - let handle = - tokio::spawn(async move { - // Declare the permit variable outside the if block to extend its scope - let _permit = - if args.sd_image - || (args.mimic3_tts || args.oai_tts || args.tts_enable) - { - // Conditionally acquire the permit and store it in an Option - Some(sem_clone_sd_image.acquire().await.expect( - "Stable Diffusion: Failed to acquire semaphore permit", - )) - } else { - // If the condition is not met, no permit is acquired, and None is stored - None - }; - - let mut sd_config = SDConfig::new(); - sd_config.prompt = paragraph_clone; - sd_config.height = Some(args.sd_height); - sd_config.width = Some(args.sd_width); - sd_config.image_position = Some(image_alignment); - if args.sd_scaled_height > 0 { - sd_config.scaled_height = Some(args.sd_scaled_height); - } - if args.sd_scaled_width > 0 { - sd_config.scaled_width = Some(args.sd_scaled_width); - } - - let args_clone = args.clone(); - let mimic3_voice_clone = mimic3_voice.clone(); - let subtitle_position_clone = subtitle_position.clone(); - - if args.sd_image { - debug!("Generating images with prompt: {}", sd_config.prompt); - - // Create MessageData for image task - let message_data_for_image = MessageData { - paragraph: sd_config.prompt.clone(), // Clone for the image task - output_id: output_id_clone.clone(), - paragraph_count, - sd_config: sd_config.clone(), // Assuming SDConfig is set up previously and is cloneable - mimic3_voice: mimic3_voice_clone.clone(), - subtitle_position: subtitle_position_clone.clone(), - args: args_clone.clone(), - }; - - // For image tasks - image_task_sender_clone - .send(message_data_for_image) - .await - .expect("Failed to send image task"); - } + let mut sd_config = SDConfig::new(); + sd_config.prompt = paragraph_clone; + sd_config.height = Some(args.sd_height); + sd_config.width = Some(args.sd_width); + sd_config.image_position = Some(image_alignment); + if args.sd_scaled_height > 0 { + sd_config.scaled_height = Some(args.sd_scaled_height); + } + if args.sd_scaled_width > 0 { + sd_config.scaled_width = Some(args.sd_scaled_width); + } - let message_data_for_speech = MessageData { - paragraph: sd_config.prompt.clone(), // Already cloned above - output_id: output_id_clone.clone(), // Clone again for speech task - paragraph_count, - sd_config: sd_config.clone(), // Clone again if necessary - mimic3_voice: mimic3_voice_clone, // Already cloned above - subtitle_position: subtitle_position_clone, // Already cloned above - args: args_clone, // Already cloned above - }; + let args_clone = args.clone(); + let mimic3_voice_clone = mimic3_voice.clone(); + let subtitle_position_clone = subtitle_position.clone(); + + if args.sd_image { + debug!("Generating images with prompt: {}", sd_config.prompt); + + // Create MessageData for image task + let message_data_for_image = MessageData { + paragraph: sd_config.prompt.clone(), // Clone for the image task + output_id: output_id_clone.clone(), + paragraph_count, + sd_config: sd_config.clone(), // Assuming SDConfig is set up previously and is cloneable + mimic3_voice: mimic3_voice_clone.clone(), + subtitle_position: subtitle_position_clone.clone(), + args: args_clone.clone(), + }; + + // For image tasks + image_task_sender_clone + .send(message_data_for_image) + .await + .expect("Failed to send image task"); + } - // For speech tasks - speech_task_sender_clone - .send(message_data_for_speech) - .await - .expect("Failed to send speech task"); - }); + let message_data_for_speech = MessageData { + paragraph: sd_config.prompt.clone(), // Already cloned above + output_id: output_id_clone.clone(), // Clone again for speech task + paragraph_count, + sd_config: sd_config.clone(), // Clone again if necessary + mimic3_voice: mimic3_voice_clone, // Already cloned above + subtitle_position: subtitle_position_clone, // Already cloned above + args: args_clone, // Already cloned above + }; - image_spawn_handles.push(handle); + // For speech tasks + speech_task_sender_clone + .send(message_data_for_speech) + .await + .expect("Failed to send speech task"); } // ** End of TTS and Image Generation ** @@ -1053,11 +1032,6 @@ async fn main() { role: "assistant".to_string(), content: answers_str.clone(), }); - - // wait for the image generation to finish - for handle in image_spawn_handles { - handle.await.unwrap(); - } } else { // Stream API Completion let stream = !args.no_stream; @@ -1113,11 +1087,13 @@ async fn main() { // stop the running threads running_processor.store(false, Ordering::SeqCst); + running_processor_images.store(false, Ordering::SeqCst); + running_processor_speech.store(false, Ordering::SeqCst); + running_processor_ndi.store(false, Ordering::SeqCst); + running_processor_twitch.store(false, Ordering::SeqCst); // Await the completion of background tasks let _ = processing_handle.await; - - // wait for the image speech and ndi tasks to finish let _ = image_processing_task.await; let _ = speech_processing_task.await; let _ = ndi_sync_task.await; diff --git a/src/pipeline.rs b/src/pipeline.rs index 852f4c1..e6899b0 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -148,7 +148,19 @@ pub struct ProcessedData { } // Function to send audio/video pairs to NDI -pub async fn send_to_ndi(processed_data: ProcessedData, args: &Args) { +pub async fn send_to_ndi(processed_data: ProcessedData, args: &Args, ndi_sem: Arc) { + let _permit = ndi_sem + .acquire() + .await + .expect("Failed to acquire ndi semaphore permit"); + + // check if args.subtitles is true, if so defined the processed_data.paragraph as a variable, if not have it be an empty string + let subtitle = if args.subtitles { + processed_data.paragraph + } else { + String::new() + }; + if let Some(image_data) = processed_data.image_data { if args.ndi_images { #[cfg(feature = "ndi")] @@ -156,7 +168,7 @@ pub async fn send_to_ndi(processed_data: ProcessedData, args: &Args) { debug!("Sending images over NDI"); send_images_over_ndi( image_data, - &processed_data.paragraph, + &subtitle, args.hardsub_font_size, &processed_data.subtitle_position, ) diff --git a/src/twitch_client.rs b/src/twitch_client.rs index a63164a..2f1f18e 100644 --- a/src/twitch_client.rs +++ b/src/twitch_client.rs @@ -1,7 +1,14 @@ use anyhow::Result; use std::io::Write; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; -pub async fn setup(nick: String, token: String, channel: Vec) -> Result<()> { +pub async fn setup( + nick: String, + token: String, + channel: Vec, + running: Arc, +) -> Result<()> { let credentials = match Some(nick).zip(Some(token)) { Some((nick, token)) => tmi::client::Credentials::new(nick, token), None => tmi::client::Credentials::anon(), @@ -21,11 +28,15 @@ pub async fn setup(nick: String, token: String, channel: Vec) -> Result< client.join_all(&channels).await?; log::info!("Joined the following channels: {}", channels.join(", ")); - run(client, channels).await + run(client, channels, running).await } -async fn run(mut client: tmi::Client, channels: Vec) -> Result<()> { - loop { +async fn run( + mut client: tmi::Client, + channels: Vec, + running: Arc, +) -> Result<()> { + while running.load(Ordering::SeqCst) { let msg = client.recv().await?; match msg.as_typed()? { tmi::Message::Privmsg(msg) => on_msg(&mut client, msg).await?, @@ -37,6 +48,7 @@ async fn run(mut client: tmi::Client, channels: Vec) -> Result<()> _ => {} }; } + Ok(()) } async fn on_msg(client: &mut tmi::Client, msg: tmi::Privmsg<'_>) -> Result<()> {