diff --git a/src/conn_handler.rs b/src/conn_handler.rs index 6d3256d..310ca9e 100644 --- a/src/conn_handler.rs +++ b/src/conn_handler.rs @@ -1,7 +1,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{RwLock, Mutex}; use std::sync::Arc; -use tracing::{info, warn, error, debug}; +use tracing::{info, warn, error}; use std::collections::HashMap; use chrono::Utc; @@ -30,12 +30,17 @@ pub async fn handle_connection( break; } let message = String::from_utf8_lossy(&buf[..n]).to_string(); - debug!(message="Received message", msg=%message); + if message.trim() == "DISCONNECT" { + socket_guard.write_all(b"DISCONNECTED\n").await.unwrap(); + socket_guard.flush().await.unwrap(); + break; + } if server_password_correct { if message.starts_with("AUTH:") { if authenticated { - let _ = socket_guard.write_all(b"ALREADY_AUTHENTICATED\n").await; + socket_guard.write_all(b"ALREADY_AUTHENTICATED\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } else { let auth_parts: Vec<&str> = message.splitn(3, ':').collect(); if auth_parts.len() == 3 { @@ -48,19 +53,22 @@ pub async fn handle_connection( match verify_session(&config, &username, session_token).await { Ok(is_valid_session) => { if !is_valid_session { - let _ = socket_guard.write_all(b"AUTH_FAILED\n").await; + socket_guard.write_all(b"AUTH_FAILED\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } else { authenticated = true; { let mut users = active_users.write().await; users.insert(username.clone(), Arc::clone(&socket)); } - let _ = socket_guard.write_all(b"AUTH_SUCCESS\n").await; + socket_guard.write_all(b"AUTH_SUCCESS\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } }, Err(e) => { let error_message = format!("AUTH_ERROR:{}\n", e); - let _ = socket_guard.write_all(error_message.as_bytes()).await; + socket_guard.write_all(error_message.as_bytes()).await.unwrap(); + socket_guard.flush().await.unwrap(); }, } } else { @@ -70,16 +78,19 @@ pub async fn handle_connection( let mut users = active_users.write().await; users.insert(username.clone(), Arc::clone(&socket)); } - let _ = socket_guard.write_all(b"AUTH_SUCCESS\n").await; + socket_guard.write_all(b"AUTH_SUCCESS\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } else { warn!(target: "auth", "Invalid AUTH message"); - let _ = socket_guard.write_all(b"AUTH_INVALID\n").await; + socket_guard.write_all(b"AUTH_INVALID\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } } else if authenticated { if message.len() > 256 { - let _ = socket_guard.write_all(b"MESSAGE_TOO_LONG\n").await; + socket_guard.write_all(b"MESSAGE_TOO_LONG\n").await.unwrap(); + socket_guard.flush().await.unwrap(); continue; } if let Some((recipient, message)) = message.split_once(':') { @@ -91,22 +102,27 @@ pub async fn handle_connection( send_direct_message(&active_users, recipient, &full_message).await; } } else { - let _ = socket_guard.write_all(b"INVALID_MESSAGE_FORMAT\n").await; + socket_guard.write_all(b"INVALID_MESSAGE_FORMAT\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } else { - let _ = socket_guard.write_all(b"NOT_AUTHENTICATED\n").await; + socket_guard.write_all(b"NOT_AUTHENTICATED\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } else if !server_password_correct && config.server.protect_server { if message.starts_with("SERVER_PASS:") { let server_password = message.trim_start_matches("SERVER_PASS:").trim(); if server_password == config.server.server_password { server_password_correct = true; - let _ = socket_guard.write_all(b"SERVER_PASS_CORRECT\n").await; + socket_guard.write_all(b"SERVER_PASS_CORRECT\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } else { - let _ = socket_guard.write_all(b"SERVER_PASS_INCORRECT\n").await; + socket_guard.write_all(b"SERVER_PASS_INCORRECT\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } else { - let _ = socket_guard.write_all(b"SERVER_PASS_REQUIRED\n").await; + socket_guard.write_all(b"SERVER_PASS_REQUIRED\n").await.unwrap(); + socket_guard.flush().await.unwrap(); } } } @@ -150,6 +166,7 @@ async fn broadcast_message( if let Err(e) = client.write_all(message.as_bytes()).await { error!(target: "server", "Failed to send message: {}", e); } else { + let _ = client.flush().await; info!(target: "server", "Broadcasted message: {}", message); } }); @@ -161,11 +178,13 @@ async fn send_direct_message(active_users: &ActiveUsers, target: &str, message: if let Some(client) = active_users.get(target) { let client = client.clone(); let message = message.to_string(); + info!(target: "server", "Sending direct message: {}", message); tokio::spawn(async move { let mut client = client.lock().await; if let Err(e) = client.write_all(message.as_bytes()).await { error!(target: "server", "Failed to send direct message: {}", e); } else { + let _ = client.flush().await; info!(target: "server", "Sent direct message: {}", message); } });