From c7ded363cdb4217dbf83ddbea81143123c16a03e Mon Sep 17 00:00:00 2001 From: Abdulrahman Salah Date: Sun, 17 Jul 2022 02:54:56 +0000 Subject: [PATCH] perf: Improvements --- .gitpod.yml | 2 +- Cargo.lock | 22 ++++++++ Cargo.toml | 2 +- Makefile | 6 ++ src/config.rs | 79 +++++++++++++++------------ src/database/postgres.rs | 10 ++-- src/database/redis.rs | 4 +- src/gateway/client.rs | 10 ++-- src/main.rs | 27 ++++++--- src/routes/auth/sessions/delete.rs | 2 - src/routes/auth/sessions/fetch.rs | 2 - src/routes/channels/delete.rs | 2 +- src/routes/channels/edit.rs | 14 ++--- src/routes/channels/kick.rs | 5 +- src/routes/invites/create.rs | 2 +- src/routes/messages/create.rs | 7 +-- src/routes/messages/delete.rs | 8 +-- src/routes/messages/edit.rs | 4 +- src/routes/messages/fetch.rs | 7 +-- src/routes/servers/channels/create.rs | 2 +- src/routes/servers/channels/delete.rs | 2 +- src/routes/servers/channels/edit.rs | 10 ++-- src/routes/servers/edit.rs | 10 ++-- src/routes/servers/invites/delete.rs | 2 +- src/routes/servers/members/edit.rs | 6 +- src/routes/servers/members/kick.rs | 2 +- src/routes/servers/roles/create.rs | 12 ++-- src/routes/servers/roles/delete.rs | 2 +- src/routes/servers/roles/edit.rs | 22 ++------ src/structures/bot.rs | 10 +--- src/structures/channel.rs | 39 ++----------- src/structures/invite.rs | 11 +--- src/structures/member.rs | 23 ++------ src/structures/message.rs | 11 +--- src/structures/role.rs | 10 +--- src/structures/server.rs | 14 +---- src/structures/session.rs | 12 +--- src/structures/user.rs | 11 +--- src/utils/badges.rs | 26 ++++----- src/utils/error.rs | 7 ++- src/utils/permissions.rs | 59 ++++++++++---------- 41 files changed, 225 insertions(+), 293 deletions(-) diff --git a/.gitpod.yml b/.gitpod.yml index bbe0c64..ff9b41d 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -1,6 +1,6 @@ tasks: - init: make dev - command: cargo test + command: make test vscode: extensions: - streetsidesoftware.code-spell-checker diff --git a/Cargo.lock b/Cargo.lock index 6c514c3..c28708f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,7 @@ dependencies = [ "fred", "futures", "governor", + "inter-struct", "lazy_static", "log", "nanoid", @@ -1042,6 +1043,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "inter-struct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322eece48523a09060b7572296a3423bb6fc963462cfc8bf5fc6117510e131cf" +dependencies = [ + "inter-struct-codegen", +] + +[[package]] +name = "inter-struct-codegen" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58ce21bc89c44afac3bdf0cb5c363b4b91cef34e14aa612f7a17bc55f42f1ab7" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ipnet" version = "2.5.0" diff --git a/Cargo.toml b/Cargo.toml index 3d6a5ee..59557a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ futures = "0.3" ctor = "0.1.22" rs-snowflake = "0.6.0" chrono = { version = "0.4.19", features = ["serde"] } - +inter-struct = "0.2.0" # Docs opg = { git = "https://github.com/abdulrahman1s/opg", rev = "24f72e7cf09da7cd61b71aedaa14383502559612" } diff --git a/Makefile b/Makefile index 264446a..79f72d1 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,9 @@ +start: + cargo run + +test: + cargo test -- -Z unstable-options --report-time + dev: docker run -e POSTGRES_PASSWORD=postgres -p 5432:5432 -d postgres docker run -p 6379:6379 -d eqalpha/keydb \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index df5e394..ba1a7ea 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,46 +1,53 @@ -use std::env; - -fn is_true(mut v: String) -> bool { - v = v.to_lowercase(); - v == "true" || v == "yes" +macro_rules! config { + ($($name:ident $t:tt $($default:expr)?),+) => { + lazy_static! { + $( + pub static ref $name: $t = std::env::var(stringify!($name)) + .unwrap_or_else(|_| { + $( if true { return $default.to_string(); } )? + panic!("{} is required", stringify!($name)); + }) + .parse::<$t>() + .unwrap(); + )+ + } + }; } -macro_rules! get { - ($key:expr) => {{ - env::var($key).expect(&format!("{} is required", $key)) - }}; - ($key:expr, $default: expr) => {{ - env::var($key).unwrap_or($default.to_string()) - }}; -} +config! { + // Networking + PORT u32 8080, + TRUST_CLOUDFLARE bool false, + + // Database + DATABASE_URI String "postgres://postgres:postgres@localhost", + REDIS_URI String "redis://localhost:6379", + REDIS_POOL_SIZE usize 100, + DATABASE_POOL_SIZE u32 100, + + // Captcha + CAPTCHA_ENABLED bool false, + CAPTCHA_TOKEN String, + CAPTCHA_KEY String, -lazy_static! { - pub static ref DATABASE_URI: String = get!("DATABASE_URI", "postgres://postgres:postgres@localhost:5432"); - pub static ref REDIS_URI: String = get!("REDIS_URI", "redis://localhost:6379"); - pub static ref REDIS_POOL_SIZE: usize = get!("REDIS_POOL_SIZE", "100").parse().unwrap(); - pub static ref CAPTCHA_ENABLED: bool = is_true(get!("CAPTCHA_ENABLED", "false")); - pub static ref CAPTCHA_TOKEN: String = get!("CAPTCHA_TOKEN"); - pub static ref CAPTCHA_KEY: String = get!("CAPTCHA_KEY"); - pub static ref PORT: String = get!("PORT", "8080"); - pub static ref EMAIL_VERIFICATION: bool = is_true(get!("EMAIL_VERIFICATION", "false")); - pub static ref REQUIRE_INVITE_TO_REGISTER: bool = - is_true(get!("REQUIRE_INVITE_TO_REGISTER", "false")); - pub static ref SENDINBLUE_API_KEY: String = get!("SENDINBLUE_API_KEY"); - pub static ref TRUST_CLOUDFLARE: bool = is_true(get!("TRUST_CLOUDFLARE", "false")); + // Email + SENDINBLUE_API_KEY String, + EMAIL_VERIFICATION bool false, + REQUIRE_INVITE_TO_REGISTER bool false, // User related - pub static ref MAX_FRIENDS: u64 = get!("MAX_FRIENDS", "1000").parse().unwrap(); - pub static ref MAX_BLOCKED: u64 = get!("MAX_BLOCKED", "1000").parse().unwrap(); - pub static ref MAX_FRIEND_REQUESTS: u64 = get!("MAX_FRIEND_REQUESTS", "100").parse().unwrap(); + MAX_FRIENDS u64 1000, + MAX_BLOCKED u64 1000, + MAX_FRIEND_REQUESTS u64 100, // Group related - pub static ref MAX_GROUPS: u64 = get!("MAX_GROUPS", "100").parse().unwrap(); - pub static ref MAX_GROUP_MEMBERS: u64 = get!("MAX_GROUP_MEMBERS", "50").parse().unwrap(); + MAX_GROUPS u64 100, + MAX_GROUP_MEMBERS u64 50, // Server related - pub static ref MAX_SERVERS: u64 = get!("MAX_SERVERS", "100").parse().unwrap(); - pub static ref MAX_SERVER_MEMBERS: u64 = get!("MAX_SERVER_MEMBERS", "10000").parse().unwrap(); - pub static ref MAX_SERVER_CHANNELS: u64 = get!("MAX_SERVER_CHANNELS", "200").parse().unwrap(); - pub static ref MAX_SERVER_ROLES: u64 = get!("MAX_SERVER_ROLES", "200").parse().unwrap(); - pub static ref MAX_SERVER_EMOJIS: u64 = get!("MAX_SERVER_EMOJIS", "150").parse().unwrap(); + MAX_SERVERS u64 100, + MAX_SERVER_MEMBERS u64 10000, + MAX_SERVER_CHANNELS u64 100, + MAX_SERVER_ROLES u64 100, + MAX_SERVER_EMOJIS u64 150 } diff --git a/src/database/postgres.rs b/src/database/postgres.rs index d8dbb0d..5773e6a 100644 --- a/src/database/postgres.rs +++ b/src/database/postgres.rs @@ -1,12 +1,12 @@ -use crate::config::DATABASE_URI; +use crate::config::*; use once_cell::sync::OnceCell; use sqlx::{postgres::PgPoolOptions, Pool, Postgres}; -static DB: OnceCell> = OnceCell::new(); +static POOL: OnceCell> = OnceCell::new(); pub async fn connect() { let pool = PgPoolOptions::new() - .max_connections(100) + .max_connections(*DATABASE_POOL_SIZE) .connect(&*DATABASE_URI) .await .expect("Couldn't connect to database"); @@ -18,9 +18,9 @@ pub async fn connect() { .await .expect("Failed to run the migration"); - DB.set(pool).unwrap(); + POOL.set(pool).unwrap(); } pub fn pool() -> &'static Pool { - DB.get().unwrap() + POOL.get().unwrap() } diff --git a/src/database/redis.rs b/src/database/redis.rs index 32b7748..c14ffd5 100644 --- a/src/database/redis.rs +++ b/src/database/redis.rs @@ -55,7 +55,7 @@ mod tests { let value: String = REDIS.get("hello").await.unwrap(); assert_eq!(value, "world"); - }) + }); } #[test] @@ -86,6 +86,6 @@ mod tests { let value: String = REDIS.get("hello").await.unwrap(); assert_eq!(value, "world"); - }) + }); } } diff --git a/src/gateway/client.rs b/src/gateway/client.rs index eeaf4e8..8c1c812 100644 --- a/src/gateway/client.rs +++ b/src/gateway/client.rs @@ -13,6 +13,10 @@ use futures::{ use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; +lazy_static! { + static ref DEFAULT_PERMISSION: Permissions = Permissions::all(); +} + pub struct Client { pub permissions: Mutex>, pub user: Mutex>, @@ -42,15 +46,13 @@ impl Client { let payload: Payload = serde_json::from_str(&payload.as_string().unwrap()).unwrap(); let mut permissions = self.permissions.lock().await; - let p = permissions - .get(&target_id) - .unwrap_or(&Permissions::ADMINISTRATOR); + let p = permissions.get(&target_id).unwrap_or(&DEFAULT_PERMISSION); match &payload { Payload::MessageCreate(_) | Payload::MessageUpdate(_) | Payload::MessageDelete(_) => { - if p.has(Permissions::VIEW_CHANNEL).is_err() { + if !p.contains(Permissions::VIEW_CHANNEL) { continue; } } diff --git a/src/main.rs b/src/main.rs index 8c5a266..3e53142 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,10 @@ use std::net::SocketAddr; #[tokio::main] async fn main() { dotenv::dotenv().ok(); - env_logger::builder().format_timestamp(None).init(); + + env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")) + .format_timestamp(None) + .init(); log::info!("Connecting to database..."); database::postgres::connect().await; @@ -49,22 +52,30 @@ async fn main() { .unwrap(); } +#[cfg(test)] pub mod tests { + use super::*; + use log::LevelFilter; use once_cell::sync::Lazy; use tokio::runtime::{Builder, Runtime}; + static RUNTIME: Lazy = + Lazy::new(|| Builder::new_multi_thread().enable_all().build().unwrap()); + pub fn run(f: F) -> F::Output { - static RT: Lazy = - Lazy::new(|| Builder::new_multi_thread().enable_all().build().unwrap()); - RT.block_on(f) + RUNTIME.block_on(f) } - #[cfg(test)] #[ctor::ctor] fn setup() { dotenv::dotenv().ok(); - env_logger::builder().format_timestamp(None).try_init().ok(); - run(super::database::postgres::connect()); - run(super::database::redis::connect()); + + env_logger::builder() + .filter_level(LevelFilter::Trace) + .format_timestamp(None) + .init(); + + run(database::postgres::connect()); + run(database::redis::connect()); } } diff --git a/src/routes/auth/sessions/delete.rs b/src/routes/auth/sessions/delete.rs index 8353b88..e5179f9 100644 --- a/src/routes/auth/sessions/delete.rs +++ b/src/routes/auth/sessions/delete.rs @@ -46,8 +46,6 @@ mod tests { ) .await .unwrap(); - - session.cleanup().await.unwrap(); }) } } diff --git a/src/routes/auth/sessions/fetch.rs b/src/routes/auth/sessions/fetch.rs index 7671664..0566cb0 100644 --- a/src/routes/auth/sessions/fetch.rs +++ b/src/routes/auth/sessions/fetch.rs @@ -35,8 +35,6 @@ mod tests { assert_eq!(results.0.len(), 1); fetch_one(Extension(user), Path(session.id)).await.unwrap(); - - session.cleanup().await.unwrap(); }) } } diff --git a/src/routes/channels/delete.rs b/src/routes/channels/delete.rs index bf6beac..fc1362a 100644 --- a/src/routes/channels/delete.rs +++ b/src/routes/channels/delete.rs @@ -7,7 +7,7 @@ pub async fn delete(Extension(user): Extension, Path(id): Path) -> Re let channel = id.channel(user.id.into()).await?; if channel.owner_id != Some(user.id) { - return Err(Error::MissingPermissions); + return Err(Error::MissingAccess); } channel.remove().await?; diff --git a/src/routes/channels/edit.rs b/src/routes/channels/edit.rs index 9e123c6..3e8ed0f 100644 --- a/src/routes/channels/edit.rs +++ b/src/routes/channels/edit.rs @@ -2,10 +2,12 @@ use crate::extractors::*; use crate::gateway::*; use crate::structures::*; use crate::utils::*; +use inter_struct::prelude::*; use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate, OpgModel)] +#[derive(Deserialize, Validate, OpgModel, StructMerge)] +#[struct_merge("crate::structures::channel::Channel")] pub struct EditGroupOptions { #[validate(length(min = 3, max = 32))] name: Option, @@ -18,13 +20,11 @@ pub async fn edit( ) -> Result> { let mut group = id.channel(user.id.into()).await?; - let permissions = Permissions::fetch(&user, None, group.id.into()).await?; + Permissions::fetch(&user, None, group.id.into()) + .await? + .has(&[Permissions::MANAGE_CHANNELS])?; - permissions.has(Permissions::MANAGE_CHANNELS)?; - - if let Some(name) = data.name { - group.name = name.into(); - } + group.merge(data); let group = group.update_all_fields(pool()).await?; diff --git a/src/routes/channels/kick.rs b/src/routes/channels/kick.rs index d6d5b66..22e1188 100644 --- a/src/routes/channels/kick.rs +++ b/src/routes/channels/kick.rs @@ -9,9 +9,10 @@ pub async fn kick( ) -> Result<()> { let target = target_id.user().await?; let mut group = group_id.channel(user.id.into()).await?; - let permissions = Permissions::fetch(&user, None, group.id.into()).await?; - permissions.has(Permissions::KICK_MEMBERS)?; + Permissions::fetch_cached(&user, None, Some(&group)) + .await? + .has(&[Permissions::KICK_MEMBERS])?; if let Some(recipients) = group.recipients.as_mut() { let exists = recipients diff --git a/src/routes/invites/create.rs b/src/routes/invites/create.rs index 4e90dcf..526236a 100644 --- a/src/routes/invites/create.rs +++ b/src/routes/invites/create.rs @@ -17,7 +17,7 @@ pub async fn create( Permissions::fetch(&user, channel.server_id, channel.id.into()) .await? - .has(Permissions::INVITE_OTHERS)?; + .has(&[Permissions::INVITE_OTHERS])?; let invite = Invite::new(user.id, channel.id, channel.server_id); diff --git a/src/routes/messages/create.rs b/src/routes/messages/create.rs index dc28896..a4586f2 100644 --- a/src/routes/messages/create.rs +++ b/src/routes/messages/create.rs @@ -18,10 +18,9 @@ pub async fn create( ValidatedJson(data): ValidatedJson, Path(channel_id): Path, ) -> Result> { - let permissions = Permissions::fetch(&user, None, channel_id.into()).await?; - - permissions.has(Permissions::VIEW_CHANNEL)?; - permissions.has(Permissions::SEND_MESSAGES)?; + Permissions::fetch(&user, None, channel_id.into()) + .await? + .has(&[Permissions::VIEW_CHANNEL, Permissions::SEND_MESSAGES])?; let mut msg = Message::new(channel_id, user.id); diff --git a/src/routes/messages/delete.rs b/src/routes/messages/delete.rs index 8db77a4..f443ed8 100644 --- a/src/routes/messages/delete.rs +++ b/src/routes/messages/delete.rs @@ -8,12 +8,12 @@ pub async fn delete( Path((channel_id, id)): Path<(i64, i64)>, ) -> Result<()> { let msg = id.message().await?; - let permissions = Permissions::fetch(&user, None, channel_id.into()).await?; - - permissions.has(Permissions::VIEW_CHANNEL)?; + let p = Permissions::fetch(&user, None, channel_id.into()).await?; if msg.author_id != user.id { - permissions.has(Permissions::MANAGE_MESSAGES)?; + p.has(&[Permissions::MANAGE_MESSAGES])?; + } else { + p.has(&[Permissions::MANAGE_MESSAGES, Permissions::VIEW_CHANNEL])?; } let attachment_ids: Vec = msg diff --git a/src/routes/messages/edit.rs b/src/routes/messages/edit.rs index 28ad23a..7971d54 100644 --- a/src/routes/messages/edit.rs +++ b/src/routes/messages/edit.rs @@ -20,12 +20,12 @@ pub async fn edit( let mut msg = id.message().await?; if msg.author_id != user.id || msg.channel_id != channel_id { - return Err(Error::MissingPermissions); + return Err(Error::MissingAccess); } Permissions::fetch(&user, None, channel_id.into()) .await? - .has(Permissions::VIEW_CHANNEL)?; + .has(&[Permissions::VIEW_CHANNEL])?; msg.content = data.content.into(); msg.edited_at = Some(Utc::now().naive_utc()); diff --git a/src/routes/messages/fetch.rs b/src/routes/messages/fetch.rs index c9c6694..2b90d35 100644 --- a/src/routes/messages/fetch.rs +++ b/src/routes/messages/fetch.rs @@ -12,10 +12,9 @@ pub async fn fetch_one( return Err(Error::MissingAccess); } - let permissions = Permissions::fetch(&user, None, channel_id.into()).await?; - - permissions.has(Permissions::VIEW_CHANNEL)?; - permissions.has(Permissions::READ_MESSAGE_HISTORY)?; + Permissions::fetch(&user, None, channel_id.into()) + .await? + .has(&[Permissions::VIEW_CHANNEL, Permissions::READ_MESSAGE_HISTORY])?; Ok(Json(msg)) } diff --git a/src/routes/servers/channels/create.rs b/src/routes/servers/channels/create.rs index 6aface7..2f55393 100644 --- a/src/routes/servers/channels/create.rs +++ b/src/routes/servers/channels/create.rs @@ -20,7 +20,7 @@ pub async fn create( ) -> Result> { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::MANAGE_CHANNELS)?; + .has(&[Permissions::MANAGE_CHANNELS])?; let count = Channel::count(&format!("server_id = {}", server_id)).await?; diff --git a/src/routes/servers/channels/delete.rs b/src/routes/servers/channels/delete.rs index b4a2cfc..e016c78 100644 --- a/src/routes/servers/channels/delete.rs +++ b/src/routes/servers/channels/delete.rs @@ -11,7 +11,7 @@ pub async fn delete( Permissions::fetch(&user, server_id.into(), id.into()) .await? - .has(Permissions::MANAGE_CHANNELS)?; + .has(&[Permissions::MANAGE_CHANNELS])?; channel.remove().await?; diff --git a/src/routes/servers/channels/edit.rs b/src/routes/servers/channels/edit.rs index db505e5..0202677 100644 --- a/src/routes/servers/channels/edit.rs +++ b/src/routes/servers/channels/edit.rs @@ -2,10 +2,12 @@ use crate::extractors::*; use crate::gateway::*; use crate::structures::*; use crate::utils::*; +use inter_struct::prelude::*; use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate, OpgModel)] +#[derive(Deserialize, Validate, OpgModel, StructMerge)] +#[struct_merge("crate::structures::channel::Channel")] pub struct EditServerChannelOptions { #[validate(length(min = 1, max = 32))] name: Option, @@ -18,13 +20,11 @@ pub async fn edit( ) -> Result> { Permissions::fetch(&user, server_id.into(), id.into()) .await? - .has(Permissions::MANAGE_CHANNELS)?; + .has(&[Permissions::MANAGE_CHANNELS])?; let mut channel = id.channel(None).await?; - if let Some(name) = data.name { - channel.name = name.into(); - } + channel.merge(data); let channel = channel.update_all_fields(pool()).await?; diff --git a/src/routes/servers/edit.rs b/src/routes/servers/edit.rs index c79d59e..6170807 100644 --- a/src/routes/servers/edit.rs +++ b/src/routes/servers/edit.rs @@ -2,10 +2,12 @@ use crate::extractors::*; use crate::gateway::*; use crate::structures::*; use crate::utils::*; +use inter_struct::prelude::*; use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate, OpgModel)] +#[derive(Deserialize, Validate, OpgModel, StructMerge)] +#[struct_merge("crate::structures::server::Server")] pub struct EditServerOptions { #[validate(length(min = 1, max = 50))] name: Option, @@ -20,11 +22,9 @@ pub async fn edit( Permissions::fetch_cached(&user, Some(&server), None) .await? - .has(Permissions::MANAGE_SERVER)?; + .has(&[Permissions::MANAGE_SERVER])?; - if let Some(name) = data.name { - server.name = name; - } + server.merge(data); let server = server.update_all_fields(pool()).await?; diff --git a/src/routes/servers/invites/delete.rs b/src/routes/servers/invites/delete.rs index 067fd53..9491a91 100644 --- a/src/routes/servers/invites/delete.rs +++ b/src/routes/servers/invites/delete.rs @@ -8,7 +8,7 @@ pub async fn delete( ) -> Result<()> { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::MANAGE_INVITES)?; + .has(&[Permissions::MANAGE_INVITES])?; id.invite(server_id.into()).await?.remove().await?; diff --git a/src/routes/servers/members/edit.rs b/src/routes/servers/members/edit.rs index b3f416c..dad2f5a 100644 --- a/src/routes/servers/members/edit.rs +++ b/src/routes/servers/members/edit.rs @@ -21,8 +21,8 @@ pub async fn edit( let p = Permissions::fetch(&user, server_id.into(), None).await?; if let Some(nickname) = &data.nickname { - p.has(Permissions::CHANGE_NICKNAME)?; - p.has(Permissions::MANAGE_NICKNAMES)?; + p.has(&[Permissions::CHANGE_NICKNAME, Permissions::MANAGE_NICKNAMES])?; + member.nickname = if nickname.is_empty() { None } else { @@ -31,7 +31,7 @@ pub async fn edit( } if let Some(ids) = &data.roles { - p.has(Permissions::MANAGE_ROLES)?; + p.has(&[Permissions::MANAGE_ROLES])?; let mut roles = Role::select() .filter("server_id = $1") diff --git a/src/routes/servers/members/kick.rs b/src/routes/servers/members/kick.rs index c97ed72..68f792a 100644 --- a/src/routes/servers/members/kick.rs +++ b/src/routes/servers/members/kick.rs @@ -10,7 +10,7 @@ pub async fn kick( if user.id != id { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::KICK_MEMBERS)?; + .has(&[Permissions::KICK_MEMBERS])?; } id.member(server_id).await?.remove().await?; diff --git a/src/routes/servers/roles/create.rs b/src/routes/servers/roles/create.rs index b49fb15..88bd8ad 100644 --- a/src/routes/servers/roles/create.rs +++ b/src/routes/servers/roles/create.rs @@ -3,10 +3,12 @@ use crate::extractors::*; use crate::gateway::*; use crate::structures::*; use crate::utils::*; +use inter_struct::prelude::*; use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate, OpgModel)] +#[derive(Deserialize, Validate, OpgModel, StructMerge)] +#[struct_merge("crate::structures::role::Role")] pub struct CreateRoleOptions { #[validate(length(min = 1, max = 32))] name: String, @@ -22,7 +24,7 @@ pub async fn create( ) -> Result> { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::MANAGE_ROLES)?; + .has(&[Permissions::MANAGE_ROLES])?; let count = Role::count(&format!("server_id = {}", server_id)).await?; @@ -30,11 +32,9 @@ pub async fn create( return Err(Error::MaximumRoles); } - let mut role = Role::new(data.name, server_id); + let mut role = Role::new(data.name.clone(), server_id); - role.permissions = data.permissions; - role.hoist = data.hoist; - role.color = data.color; + role.merge(data); let role = role.save().await?; diff --git a/src/routes/servers/roles/delete.rs b/src/routes/servers/roles/delete.rs index cbdc2b3..9214850 100644 --- a/src/routes/servers/roles/delete.rs +++ b/src/routes/servers/roles/delete.rs @@ -9,7 +9,7 @@ pub async fn delete( ) -> Result<()> { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::MANAGE_ROLES)?; + .has(&[Permissions::MANAGE_ROLES])?; id.role(server_id).await?.remove().await?; diff --git a/src/routes/servers/roles/edit.rs b/src/routes/servers/roles/edit.rs index bcbd153..dde3df5 100644 --- a/src/routes/servers/roles/edit.rs +++ b/src/routes/servers/roles/edit.rs @@ -2,10 +2,12 @@ use crate::extractors::*; use crate::gateway::*; use crate::structures::*; use crate::utils::*; +use inter_struct::prelude::*; use serde::Deserialize; use validator::Validate; -#[derive(Deserialize, Validate, OpgModel)] +#[derive(Deserialize, Validate, OpgModel, StructMerge)] +#[struct_merge("crate::structures::role::Role")] pub struct EditRoleOptions { #[validate(length(min = 1, max = 32))] name: Option, @@ -21,25 +23,11 @@ pub async fn edit( ) -> Result> { Permissions::fetch(&user, server_id.into(), None) .await? - .has(Permissions::MANAGE_ROLES)?; + .has(&[Permissions::MANAGE_ROLES])?; let mut role = id.role(server_id).await?; - if let Some(name) = &data.name { - role.name = name.clone(); - } - - if let Some(permissions) = data.permissions { - role.permissions = permissions; - } - - if let Some(hoist) = data.hoist { - role.hoist = hoist; - } - - if let Some(color) = data.color { - role.color = color; - } + role.merge(data); let role = role.update_all_fields(pool()).await?; diff --git a/src/structures/bot.rs b/src/structures/bot.rs index ae73fee..51dc138 100644 --- a/src/structures/bot.rs +++ b/src/structures/bot.rs @@ -36,13 +36,6 @@ impl Bot { bot } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.owner_id.user().await?.remove().await?; - Ok(()) - } } impl Base for Bot {} @@ -56,8 +49,7 @@ mod tests { fn create() { run(async { let bot = Bot::faker().await.save().await.unwrap(); - let bot = Bot::find_one(bot.id).await.unwrap(); - bot.cleanup().await.unwrap(); + Bot::find_one(bot.id).await.unwrap(); }); } } diff --git a/src/structures/channel.rs b/src/structures/channel.rs index 8ede076..7f12bef 100644 --- a/src/structures/channel.rs +++ b/src/structures/channel.rs @@ -238,30 +238,6 @@ impl Channel { _ => panic!("Unsupported type"), } } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - - if self.is_group() || self.is_dm() { - for id in self.recipients.as_ref().unwrap() { - id.user().await?.remove().await?; - } - - if self.owner_id.is_none() { - self.remove().await?; - } - } else if self.in_server() { - self.server_id - .unwrap() - .server(None) - .await? - .cleanup() - .await?; - } - - Ok(()) - } } impl Base for Channel {} @@ -279,8 +255,7 @@ mod tests { .save() .await .unwrap(); - let channel = Channel::find_one(channel.id).await.unwrap(); - channel.cleanup().await.unwrap(); + Channel::find_one(channel.id).await.unwrap(); }); } @@ -292,8 +267,7 @@ mod tests { .save() .await .unwrap(); - let channel = Channel::find_one(channel.id).await.unwrap(); - channel.cleanup().await.unwrap(); + Channel::find_one(channel.id).await.unwrap(); }); } @@ -305,8 +279,7 @@ mod tests { .save() .await .unwrap(); - let channel = Channel::find_one(channel.id).await.unwrap(); - channel.cleanup().await.unwrap(); + Channel::find_one(channel.id).await.unwrap(); }) } @@ -318,8 +291,7 @@ mod tests { .save() .await .unwrap(); - let channel = Channel::find_one(channel.id).await.unwrap(); - channel.cleanup().await.unwrap(); + Channel::find_one(channel.id).await.unwrap(); }) } @@ -331,8 +303,7 @@ mod tests { .save() .await .unwrap(); - let channel = Channel::find_one(channel.id).await.unwrap(); - channel.cleanup().await.unwrap(); + Channel::find_one(channel.id).await.unwrap(); }) } } diff --git a/src/structures/invite.rs b/src/structures/invite.rs index d7e2163..a02d864 100644 --- a/src/structures/invite.rs +++ b/src/structures/invite.rs @@ -49,14 +49,6 @@ impl Invite { invite } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.inviter_id.user().await?.remove().await?; - self.channel_id.channel(None).await?.cleanup().await?; - Ok(()) - } } impl Base for Invite {} @@ -70,8 +62,7 @@ mod tests { fn create() { run(async { let invite = Invite::faker().await.save().await.unwrap(); - let invite = Invite::find_one(invite.id).await.unwrap(); - invite.cleanup().await.unwrap(); + Invite::find_one(invite.id).await.unwrap(); }); } } diff --git a/src/structures/member.rs b/src/structures/member.rs index 68bd523..6c4f578 100644 --- a/src/structures/member.rs +++ b/src/structures/member.rs @@ -35,9 +35,9 @@ impl Member { } } - pub async fn fetch_roles(&self) -> Vec { + pub async fn fetch_roles(&self) -> Result, ormlite::Error> { if self.roles.is_empty() { - return vec![]; + return Ok(vec![]); } Role::select() @@ -45,7 +45,6 @@ impl Member { .bind(self.roles.clone()) .fetch_all(pool()) .await - .unwrap() } #[cfg(test)] @@ -59,14 +58,6 @@ impl Member { member } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.server_id.server(None).await?.remove().await?; - self.id.user().await?.remove().await?; - Ok(()) - } } impl Base for Member {} @@ -81,15 +72,13 @@ mod tests { run(async { let member = Member::faker().await.save().await.unwrap(); - let member = Member::select() + Member::select() .filter("id = $1 AND server_id = $2") .bind(member.id) .bind(member.server_id) .fetch_one(pool()) .await - .expect("Cannot fetch member after it get saved"); - - member.cleanup().await.unwrap(); + .unwrap(); }) } @@ -105,11 +94,9 @@ mod tests { member.roles.push(role.id); let member = member.save().await.unwrap(); - let roles = member.fetch_roles().await; + let roles = member.fetch_roles().await.unwrap(); assert_eq!(roles.len(), 1); - - member.cleanup().await.unwrap(); }); } } diff --git a/src/structures/message.rs b/src/structures/message.rs index f22ba99..33f492a 100644 --- a/src/structures/message.rs +++ b/src/structures/message.rs @@ -55,14 +55,6 @@ impl Message { message } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.author_id.user().await?.remove().await?; - self.channel_id.channel(None).await?.cleanup().await?; - Ok(()) - } } impl Base for Message {} @@ -76,8 +68,7 @@ mod tests { fn create() { run(async { let msg = Message::faker().await.save().await.unwrap(); - let msg = Message::find_one(msg.id).await.unwrap(); - msg.cleanup().await.unwrap(); + Message::find_one(msg.id).await.unwrap(); }); } } diff --git a/src/structures/role.rs b/src/structures/role.rs index 1046700..d439d72 100644 --- a/src/structures/role.rs +++ b/src/structures/role.rs @@ -36,13 +36,6 @@ impl Role { server.save().await.unwrap(); role } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.server_id.server(None).await?.cleanup().await?; - Ok(()) - } } impl Base for Role {} @@ -56,8 +49,7 @@ mod tests { fn create() { run(async { let role = Role::faker().await.save().await.unwrap(); - let role = Role::find_one(role.id).await.unwrap(); - role.cleanup().await.unwrap(); + Role::find_one(role.id).await.unwrap(); }) } } diff --git a/src/structures/server.rs b/src/structures/server.rs index 7f10dbe..0a9ea3b 100644 --- a/src/structures/server.rs +++ b/src/structures/server.rs @@ -79,13 +79,6 @@ impl Server { server } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.owner_id.user().await?.remove().await?; - Ok(()) - } } impl Base for Server {} @@ -99,10 +92,7 @@ mod tests { fn create() { run(async { let server = Server::faker().await.save().await.unwrap(); - let server = Server::find_one(server.id) - .await - .expect("Server not found after save"); - server.cleanup().await.unwrap(); - }) + Server::find_one(server.id).await.unwrap(); + }); } } diff --git a/src/structures/session.rs b/src/structures/session.rs index 882aa45..9d5a053 100644 --- a/src/structures/session.rs +++ b/src/structures/session.rs @@ -36,13 +36,6 @@ impl Session { session } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - use crate::utils::Ref; - self.user_id.user().await?.remove().await?; - Ok(()) - } } impl Base for Session {} @@ -56,8 +49,7 @@ mod tests { fn create() { run(async { let session = Session::faker().await.save().await.unwrap(); - let session = Session::find_one(session.id).await.unwrap(); - session.cleanup().await.unwrap(); - }) + Session::find_one(session.id).await.unwrap(); + }); } } diff --git a/src/structures/user.rs b/src/structures/user.rs index 3705813..a965fed 100644 --- a/src/structures/user.rs +++ b/src/structures/user.rs @@ -132,12 +132,6 @@ impl User { user.verified = true; user } - - #[cfg(test)] - pub async fn cleanup(self) -> Result<(), crate::utils::Error> { - self.remove().await?; - Ok(()) - } } impl Base for User {} @@ -151,8 +145,7 @@ mod tests { fn create() { run(async { let user = User::faker().save().await.unwrap(); - let user = User::find_one(user.id).await.unwrap(); - user.cleanup().await.unwrap(); - }) + User::find_one(user.id).await.unwrap(); + }); } } diff --git a/src/utils/badges.rs b/src/utils/badges.rs index 7b232a6..37dbb75 100644 --- a/src/utils/badges.rs +++ b/src/utils/badges.rs @@ -1,5 +1,8 @@ use bitflags::bitflags; -use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; +use serde::{ + de::{Error, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use sqlx::{ encode::IsNull, error::BoxDynError, @@ -9,18 +12,12 @@ use sqlx::{ use std::fmt; bitflags! { + #[derive(Default)] pub struct Badges: u64 { const STAFF = 1 << 1; const DEVELOPER = 1 << 2; const SUPPORTER = 1 << 3; const TRANSLATOR = 1 << 4; - const DEFAULT = 0; - } -} - -impl Default for Badges { - fn default() -> Self { - Badges::DEFAULT } } @@ -63,33 +60,30 @@ impl<'de> Visitor<'de> for BadgesVisitor { fn visit_string(self, v: String) -> Result where - E: serde::de::Error, + E: Error, { self.visit_u64(v.parse().map_err(E::custom)?) } fn visit_str(self, v: &str) -> Result where - E: serde::de::Error, + E: Error, { self.visit_u64(v.parse().map_err(E::custom)?) } fn visit_i64(self, v: i64) -> Result where - E: serde::de::Error, + E: Error, { self.visit_u64(v as u64) } fn visit_u64(self, v: u64) -> Result where - E: serde::de::Error, + E: Error, { - match Badges::from_bits(v) { - Some(bits) => Ok(bits), - _ => Err(E::custom("Invalid bits")), - } + Badges::from_bits(v).ok_or_else(|| E::custom("Invalid bits")) } } diff --git a/src/utils/error.rs b/src/utils/error.rs index 8720da7..1d01a79 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -1,3 +1,4 @@ +use super::Permissions; use crate::middlewares::ratelimit::RateLimitInfo; use axum::{ extract::rejection::JsonRejection, @@ -21,7 +22,11 @@ quick_error! { AccountVerificationRequired { display("You need to verify your account in order to perform this action") } InvalidToken { display("Unauthorized. Provide a valid token and try again") } EmailAlreadyInUse { display("This email already in use") } - MissingPermissions { display("You lack permissions to perform that action") } + + MissingPermissions(missing: Vec) { + display("You lack permissions to perform that action, missing: {:?}", missing) + } + EmptyMessage { display("Cannot send an empty message") } RequireInviteCode { display("You must have an invite code to perform this action") } InviteAlreadyTaken { display("This invite already used") } diff --git a/src/utils/permissions.rs b/src/utils/permissions.rs index a4a1ab8..f8dbfb7 100644 --- a/src/utils/permissions.rs +++ b/src/utils/permissions.rs @@ -1,7 +1,10 @@ use crate::structures::*; use crate::utils::{Error, Ref, Result}; use bitflags::bitflags; -use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; +use serde::{ + de::{Error as SerdeError, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use sqlx::{ encode::IsNull, error::BoxDynError, @@ -11,8 +14,8 @@ use sqlx::{ use std::fmt; bitflags! { + #[derive(Default)] pub struct Permissions: u64 { - const ADMINISTRATOR = 1 << 0; const VIEW_CHANNEL = 1 << 1; const SEND_MESSAGES = 1 << 2; const READ_MESSAGE_HISTORY = 1 << 3; @@ -28,19 +31,16 @@ bitflags! { const KICK_MEMBERS = 1 << 13; const CHANGE_NICKNAME = 1 << 14; const INVITE_OTHERS = 1 << 15; - const DEFAULT = 0; } } lazy_static! { - pub static ref DEFAULT_PERMISSION_DM: Permissions = Permissions::DEFAULT - | Permissions::VIEW_CHANNEL + pub static ref DEFAULT_PERMISSION_DM: Permissions = Permissions::VIEW_CHANNEL | Permissions::SEND_MESSAGES | Permissions::EMBED_LINKS | Permissions::UPLOAD_FILES | Permissions::READ_MESSAGE_HISTORY; - pub static ref DEFAULT_PERMISSION_EVERYONE: Permissions = Permissions::DEFAULT - | Permissions::VIEW_CHANNEL + pub static ref DEFAULT_PERMISSION_EVERYONE: Permissions = Permissions::VIEW_CHANNEL | Permissions::SEND_MESSAGES | Permissions::EMBED_LINKS | Permissions::UPLOAD_FILES @@ -53,13 +53,13 @@ impl Permissions { server: Option<&Server>, channel: Option<&Channel>, ) -> Result { - let mut p = Permissions::DEFAULT; + let mut p = Permissions::default(); if let Some(server) = server { - p.set(Permissions::ADMINISTRATOR, server.owner_id == user.id); + p.set(Permissions::all(), server.owner_id == user.id); p.insert(server.permissions); - if p.contains(Permissions::ADMINISTRATOR) { + if p.is_all() { return Ok(p); } @@ -94,7 +94,7 @@ impl Permissions { { // for group owners if channel.owner_id == Some(user.id) { - p.set(Permissions::ADMINISTRATOR, true); + p = Permissions::all(); return Ok(p); } @@ -153,9 +153,21 @@ impl Permissions { Permissions::fetch_cached(user, server.as_ref(), channel.as_ref()).await } - pub fn has(&self, bits: Permissions) -> Result<()> { - if !self.contains(Permissions::ADMINISTRATOR) && !self.contains(bits) { - return Err(Error::MissingPermissions); + pub fn has(self, bits: &[Permissions]) -> Result<()> { + if self.is_all() { + return Ok(()); + } + + let mut missing = vec![]; + + for &bit in bits { + if !self.contains(bit) { + missing.push(bit); + } + } + + if !missing.is_empty() { + return Err(Error::MissingPermissions(missing)); } Ok(()) @@ -181,12 +193,6 @@ impl<'r> Decode<'r, Postgres> for Permissions { } } -impl Default for Permissions { - fn default() -> Self { - Permissions::DEFAULT - } -} - impl Serialize for Permissions { fn serialize(&self, serializer: S) -> Result where @@ -207,33 +213,30 @@ impl<'de> Visitor<'de> for PermissionsVisitor { fn visit_string(self, v: String) -> Result where - E: serde::de::Error, + E: SerdeError, { self.visit_u64(v.parse().map_err(E::custom)?) } fn visit_str(self, v: &str) -> Result where - E: serde::de::Error, + E: SerdeError, { self.visit_u64(v.parse().map_err(E::custom)?) } fn visit_i64(self, v: i64) -> Result where - E: serde::de::Error, + E: SerdeError, { self.visit_u64(v as u64) } fn visit_u64(self, v: u64) -> Result where - E: serde::de::Error, + E: SerdeError, { - match Permissions::from_bits(v) { - Some(bits) => Ok(bits), - _ => Err(E::custom("Invalid bits")), - } + Permissions::from_bits(v).ok_or_else(|| E::custom("Invalid bits")) } }