diff --git a/Cargo.lock b/Cargo.lock index 3782c1a7..4a816998 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -584,6 +584,7 @@ dependencies = [ "dashmap", "deadpool-postgres", "flexi_logger", + "futures", "git-version", "lazy_static", "log", @@ -599,6 +600,7 @@ dependencies = [ "toml", "twilight", "url", + "uuid", ] [[package]] @@ -2142,6 +2144,16 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05e42f7c18b8f902290b009cde6d651262f956c98bc51bca4cd1d511c9cd85c7" +[[package]] +name = "uuid" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fde2f6a4bea1d6e007c4ad38c6839fa71cbb63b6dbf5b595aa38dc9b1093c11" +dependencies = [ + "rand", + "serde", +] + [[package]] name = "vcpkg" version = "0.2.8" diff --git a/Cargo.toml b/Cargo.toml index 170c3f8a..5a4039dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ chrono = "0.4" deadpool-postgres={version = "0.5", features=["config"]} dashmap = "3.11" flexi_logger = { version = "0.15", default_features = false, features = ["colors", "specfile", "ziplogs"] } +futures = "0.3" git-version = "0.3" lazy_static = "1.4" log = "0.4" @@ -28,6 +29,7 @@ toml = "0.5" twilight = { git = "https://github.com/AEnterprise/twilight/", branch="gearbot" } tokio-postgres = { version = "0.5", default_features = false } url = "2.1" +uuid = { version = "0.8", features = ["serde", "v4"] } [profile.dev] debug = 0 diff --git a/src/core/context.rs b/context copy.rs similarity index 100% rename from src/core/context.rs rename to context copy.rs diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 524bff27..fbcd0f77 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -4,6 +4,7 @@ use crate::commands::meta::nodes::CommandNode; use crate::{command, subcommands}; pub mod basic; pub mod meta; +pub mod moderation; static ROOT_NODE: OnceCell = OnceCell::new(); @@ -18,7 +19,8 @@ pub fn get_root() -> &'static CommandNode { command!("coinflip", basic::coinflip), command!("ping", basic::ping), command!("echo", basic::echo), - command!("about", basic::about) + command!("about", basic::about), + command!("userinfo", moderation::userinfo) )) .ok() .unwrap(); diff --git a/src/commands/moderation/mod.rs b/src/commands/moderation/mod.rs new file mode 100644 index 00000000..e06415f3 --- /dev/null +++ b/src/commands/moderation/mod.rs @@ -0,0 +1,3 @@ +mod userinfo; + +pub use userinfo::userinfo; diff --git a/src/commands/moderation/userinfo.rs b/src/commands/moderation/userinfo.rs new file mode 100644 index 00000000..a16ebc76 --- /dev/null +++ b/src/commands/moderation/userinfo.rs @@ -0,0 +1,195 @@ +use crate::core::Context; +use crate::parser::Parser; +use crate::utils::Emoji; +use crate::utils::{CommandError, Error}; +use crate::{utils, CommandResult}; +use chrono::{DateTime, Utc}; +use log::debug; +use serde::export::TryFrom; +use std::borrow::Borrow; +use std::sync::Arc; +use std::time::Duration; +use twilight::builders::embed::EmbedBuilder; +use twilight::model::channel::Message; +use twilight::model::id::ChannelId; +use twilight::model::user::UserFlags; + +pub async fn userinfo(ctx: Arc, msg: Message, mut parser: Parser) -> CommandResult { + if msg.guild_id.is_none() { + return Err(Error::CmdError(CommandError::NoDM)); + } + + let user = parser.get_user().await?; + + //set some things that are the same regardless + let mut content = "".to_string(); + + let mut builder = EmbedBuilder::new(); + let mut author_builder = builder + .author() + .name(format!("{}#{}", user.name, user.discriminator)); + if user.avatar.is_some() { + let avatar = user.avatar.as_ref().unwrap(); + let extension = if avatar.starts_with("a_") { + "gif" + } else { + "png" + }; + author_builder = author_builder.icon_url(format!( + "https://cdn.discordapp.com/avatars/{}/{}.{}", + user.id, + user.avatar.as_ref().unwrap(), + extension + )) + } + builder = author_builder.commit(); + + //add badges + let flags = match user.public_flags { + Some(flags) => flags, + None => { + // we already know for sure the user will exist + let user = ctx.http.user(user.id.0).await?.unwrap(); + //TODO insert in cache when possible + user.public_flags.unwrap() + } + }; + + if flags.contains(UserFlags::DISCORD_EMPLOYEE) { + content += Emoji::StaffBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::DISCORD_PARTNER) { + content += Emoji::PartnerBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::HYPESQUAD_EVENTS) { + content += Emoji::HypesquadEvents.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::BUG_HUNTER) { + content += Emoji::BugHunterBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::HOUSE_BRAVERY) { + content += Emoji::BraveryBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::HOUSE_BRILLIANCE) { + content += Emoji::BrillianceBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::HOUSE_BALANCE) { + content += Emoji::BalanceBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::BUG_HUNTER_LEVEL_2) { + content += Emoji::BugHunterLvl2Badge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::VERIFIED_BOT_DEVELOPER) { + content += Emoji::VerifiedBotDevBadge.for_chat(); + content += " "; + } + + if flags.contains(UserFlags::EARLY_SUPPORTER) { + content += Emoji::EarlySupporterBadge.for_chat(); + } + + content += if user.bot { + Emoji::Robot.for_chat() + } else { + "" + }; + + let created_at = utils::snowflake_timestamp(user.id.0); + + content += &format!( + "\n**User id**: {}\n**Account created on**: {}\n**Account Age**: {}\n\n", + user.id, + created_at.format("%A %d %B %Y (%T)"), + utils::age(created_at, Utc::now(), 2) + ); + + match ctx + .cache + .member(msg.guild_id.unwrap(), user.id) + .await + .unwrap() + { + Some(member) => { + if member.roles.first().is_some() { + let role = member.roles.first().unwrap().clone(); + let cached_role = ctx.cache.role(role).await?.unwrap(); + builder = builder.color(cached_role.color); + let (joined, ago) = match &member.joined_at { + Some(joined) => { + let joined = DateTime::from_utc( + DateTime::parse_from_str(joined, "%FT%T%.f%z") + .unwrap() + .naive_utc(), + Utc, + ); + ( + joined.format("%A %d %B %Y (%T)").to_string(), + utils::age(joined, Utc::now(), 2), + ) + } + None => ("Unknown".to_string(), "Unknown".to_string()), + }; + + let mut count = 0; + let mut roles = "".to_string(); + for role in &member.roles { + if count > 0 { + roles += ", "; + } + roles += &format!("<@&{}>", role.0); + count += 1; + if count == 3 { + roles += &format!(" and {} more", member.roles.len() - 3); + break; + } + } + + content += &format!( + "**Joined on**: {}\n**Been here for**: {}\n**Roles**:{}", + joined, ago, roles + ); + match &member.premium_since { + Some(s) => { + let since: DateTime = DateTime::from_utc( + DateTime::parse_from_str(&*s, "%FT%T%.f%z") + .unwrap() + .naive_utc(), + Utc, + ); + content += &format!("**Boosting this server since**: {}", since); + } + None => {} + } + } + } + None => { + builder = builder.color(0x00cea2); + } + } + + builder = builder.description(content); + + ctx.http + .create_message(msg.channel_id) + .content(format!("User information about <@!{}>", user.id)) + .embed(builder.build()) + .await?; + + Ok(()) +} diff --git a/src/core/context/cache.rs b/src/core/context/cache.rs new file mode 100644 index 00000000..cb65cd78 --- /dev/null +++ b/src/core/context/cache.rs @@ -0,0 +1,45 @@ +use crate::core::Context; +use crate::utils::ParseError::MemberNotFoundById; +use crate::utils::{Error, ParseError}; +use futures::channel::oneshot; +use log::debug; +use std::sync::Arc; +use twilight::http::error::Error::Response; +use twilight::http::error::ResponseError::{Client, Server}; +use twilight::http::error::{Error as HttpError, ResponseError}; +use twilight::model::gateway::payload::{MemberChunk, RequestGuildMembers}; +use twilight::model::gateway::presence::Presence; +use twilight::model::guild::Member; +use twilight::model::id::UserId; +use twilight::model::user::User; +use uuid::Uuid; + +impl Context { + pub async fn get_user(&self, user_id: UserId) -> Result, Error> { + match self.cache.user(user_id).await? { + Some(user) => Ok(user), + None => { + // let's see if we can get em from the api + let result = self.http.user(user_id.0).await; + //TODO: cache in redis + + match result { + Ok(u) => { + let user = u.unwrap(); // there isn't a codepath that can even give none for this atm + Ok(Arc::new(user)) + } + Err(error) => { + //2 options here: + //1) drill down 3 layers and get a headache trying to deal with moving and re-assembling errors to figure out the status code + //2) just get the string and find the code in there + if format!("{:?}", error).contains("status: 404") { + Err(Error::ParseError(ParseError::InvalidUserID(user_id.0))) + } else { + Err(Error::TwilightHttp(error)) + } + } + } + } + } + } +} diff --git a/src/core/context/database.rs b/src/core/context/database.rs new file mode 100644 index 00000000..e16c905a --- /dev/null +++ b/src/core/context/database.rs @@ -0,0 +1,299 @@ +use aes_gcm::aead::generic_array::GenericArray; +use aes_gcm::{ + aead::{Aead, NewAead}, + Aes256Gcm, +}; +use dashmap::mapref::one::Ref; +use log::info; +use postgres_types::Type; +use rand::{thread_rng, RngCore}; +use serde_json; +use twilight::model::{ + channel::message::{Message, MessageType}, + id::{ChannelId, GuildId, MessageId, UserId}, +}; + +use crate::utils::{Error, FetchError}; +use crate::{ + core::{Context, GuildConfig}, + EncryptionKey, +}; + +#[derive(Debug)] +pub struct UserMessage { + pub content: String, + pub author: UserId, + pub channel: ChannelId, + pub guild: GuildId, + pub msg_type: MessageType, + pub pinned: bool, +} + +impl Context { + pub async fn fetch_user_message(&self, id: MessageId) -> Result { + let client = self.pool.get().await?; + + let statement = client + .prepare_typed("SELECT * from message where id=$1", &[Type::INT8]) + .await?; + + let fetch_id = id.0 as i64; + + let rows = client.query(&statement, &[&fetch_id]).await?; + + if let Some(stored_msg) = rows.get(0) { + let encrypted_message: &[u8] = stored_msg.get(1); + let author: i64 = stored_msg.get(2); + let channel: i64 = stored_msg.get(3); + let guild_id = { + let raw: i64 = stored_msg.get(4); + GuildId(raw as u64) + }; + + let raw_msg_type: i16 = stored_msg.get(5); + let pinned = stored_msg.get(6); + + // TODO: This should exist in twilight via a TryFrom + let msg_type = match raw_msg_type as u8 { + 0 => MessageType::Regular, + 1 => MessageType::RecipientAdd, + 2 => MessageType::RecipientRemove, + 3 => MessageType::Call, + 4 => MessageType::ChannelNameChange, + 5 => MessageType::ChannelIconChange, + 6 => MessageType::ChannelMessagePinned, + 7 => MessageType::GuildMemberJoin, + 8 => MessageType::UserPremiumSub, + 9 => MessageType::UserPremiumSubTier1, + 10 => MessageType::UserPremiumSubTier2, + 11 => MessageType::UserPremiumSubTier3, + 12 => MessageType::ChannelFollowAdd, + 14 => MessageType::GuildDiscoveryDisqualified, + 15 => MessageType::GuildDiscoveryRequalified, + _ => unimplemented!(), + }; + + let msg_id = fetch_id as u64; + + let start = std::time::Instant::now(); + + let plaintext = { + let guild_key = self.get_guild_encryption_key(guild_id).await?; + + decrypt_bytes(encrypted_message, &guild_key, msg_id) + }; + + let finish = std::time::Instant::now(); + + info!( + "It took {}ms to decrypt the message!", + (finish - start).as_millis() + ); + + let plaintext_string = String::from_utf8_lossy(&plaintext).to_string(); + + let assembled_message = UserMessage { + content: plaintext_string, + author: (author as u64).into(), + channel: (channel as u64).into(), + guild: guild_id, + msg_type, + pinned, + }; + + Ok(assembled_message) + } else { + Err(FetchError::ShouldExist.into()) + } + } + + pub async fn insert_user_message(&self, msg: &Message, guild_id: GuildId) -> Result<(), Error> { + // All guilds need to have a config before anything can happen thanks to encryption. + let _ = self.get_config(guild_id).await?; + let client = self.pool.get().await?; + + let msg_id = msg.id.0 as i64; + let msg_type = msg.kind as i16; + let author_id = msg.author.id.0 as i64; + let channel_id = msg.channel_id.0 as i64; + let pinned = msg.pinned; + + let start = std::time::Instant::now(); + let ciphertext = { + let plaintext = msg.content.as_bytes(); + let guild_key = self.get_guild_encryption_key(guild_id).await?; + + encrypt_bytes(plaintext, &guild_key, msg_id as u64) + }; + + let finish_crypto = std::time::Instant::now(); + + info!( + "It took {}ms to encrypt the user message!", + (finish_crypto - start).as_millis() + ); + + let guild_id = guild_id.0 as i64; + let statement = client + .prepare_typed( + "INSERT INTO message (id, content, author_id, channel_id, guild_id, type, pinned) + VALUES ($1, $2, $3, $4, $5, $6, $7)", + &[ + Type::INT8, + Type::BYTEA, + Type::INT8, + Type::INT8, + Type::INT8, + Type::INT2, + Type::BOOL, + ], + ) + .await?; + + client + .execute( + &statement, + &[ + &msg_id, + &ciphertext, + &author_id, + &channel_id, + &guild_id, + &msg_type, + &pinned, + ], + ) + .await?; + + info!("Logged a user message!"); + + Ok(()) + } + + async fn get_guild_encryption_key(&self, guild_id: GuildId) -> Result { + let client = self.pool.get().await?; + + let fetch_id = guild_id.0 as i64; + + let statement = client + .prepare_typed( + "SELECT encryption_key from guildconfig where id=$1", + &[Type::INT8], + ) + .await?; + + let rows = client.query(&statement, &[&fetch_id]).await?; + + if let Some(ek) = rows.get(0) { + let ek_bytes = ek.get(0); + + let guild_key = { + let master_key = self.__get_master_key().unwrap(); + + let decrypted_gk_bytes = decrypt_bytes(ek_bytes, master_key, fetch_id as u64); + EncryptionKey::clone_from_slice(&decrypted_gk_bytes) + }; + + Ok(guild_key) + } else { + Err(FetchError::ShouldExist.into()) + } + } + + pub async fn get_config( + &self, + guild_id: GuildId, + ) -> Result, Error> { + match self.configs.get(&guild_id) { + Some(config) => Ok(config), + None => { + let client = self.pool.get().await?; + let statement = client + .prepare_typed("SELECT config from guildconfig where id=$1", &[Type::INT8]) + .await?; + + let rows = client.query(&statement, &[&(guild_id.0 as i64)]).await?; + + let config: GuildConfig = if rows.is_empty() { + let encrypted_guild_key = { + let mut csprng = thread_rng(); + // Each guild has its own encryption key. This allows us, in the event of a compromise of the master key, + // to simply re-encrypt the guild keys instead of millions of messages. + let mut guild_encryption_key = [0u8; 32]; + csprng.fill_bytes(&mut guild_encryption_key); + + let master_key = self.__get_master_key().unwrap(); + encrypt_bytes(&guild_encryption_key, master_key, guild_id.0 as u64) + }; + + let config = GuildConfig::default(); + info!("No config found for {}, inserting blank one", guild_id); + let statement = client + .prepare_typed( + "INSERT INTO guildconfig (id, config, encryption_key) VALUES ($1, $2, $3)", + &[Type::INT8, Type::JSON, Type::BYTEA], + ) + .await?; + client + .execute( + &statement, + &[ + &(guild_id.0 as i64), + &serde_json::to_value(&GuildConfig::default()).unwrap(), + &encrypted_guild_key, + ], + ) + .await?; + + config + } else { + serde_json::from_value(rows[0].get(0))? + }; + + self.configs.insert(guild_id, config); + Ok(self.configs.get(&guild_id).unwrap()) + } + } + } + + /// Returns the master key that is used to encrypt and decrypt guild keys. + fn __get_master_key(&self) -> Option<&EncryptionKey> { + if let Some(mk_bytes) = &self.__static_master_key { + let key = GenericArray::from_slice(mk_bytes); + Some(key) + } else { + None + } + } +} + +fn encrypt_bytes(plaintext: &[u8], key: &EncryptionKey, id: u64) -> Vec { + let aead = Aes256Gcm::new(*key); + + // Since nonce's only never need to be reused, and Discor's snowflakes for messages + // are unique, we can use the messasge id to construct the nonce with its 64 bits, and then + // pad the rest with zeros. + let mut nonce_bytes = [0u8; 12]; + let msg_id_bytes = id.to_le_bytes(); + nonce_bytes[..8].copy_from_slice(&msg_id_bytes); + nonce_bytes[8..].copy_from_slice(&[0u8; 4]); + + let nonce = GenericArray::from_slice(&nonce_bytes); + + aead.encrypt(&nonce, plaintext) + .expect("Failed to encrypt an object!") +} + +fn decrypt_bytes(ciphertext: &[u8], key: &EncryptionKey, id: u64) -> Vec { + let aead = Aes256Gcm::new(*key); + + let mut nonce_bytes = [0u8; 12]; + let msg_id_bytes = id.to_le_bytes(); + nonce_bytes[..8].copy_from_slice(&msg_id_bytes); + nonce_bytes[8..].copy_from_slice(&[0u8; 4]); + + let nonce = GenericArray::from_slice(&nonce_bytes); + + aead.decrypt(&nonce, ciphertext) + .expect("Failed to decrypt an object!") +} diff --git a/src/core/context/mod.rs b/src/core/context/mod.rs new file mode 100644 index 00000000..51d61a44 --- /dev/null +++ b/src/core/context/mod.rs @@ -0,0 +1,64 @@ +use crate::core::context::stats::BotStats; +use crate::core::GuildConfig; +use dashmap::DashMap; +use deadpool_postgres::Pool; +use futures::channel::oneshot::Sender; +use std::sync::RwLock; +use twilight::cache::InMemoryCache; +use twilight::gateway::Cluster; +use twilight::http::Client as HttpClient; +use twilight::model::channel::Message; +use twilight::model::gateway::payload::MemberChunk; +use twilight::model::id::GuildId; +use twilight::model::user::CurrentUser; + +pub struct Context { + pub cache: InMemoryCache, + pub cluster: Cluster, + pub http: HttpClient, + pub stats: BotStats, + pub status_type: RwLock, + pub status_text: RwLock, + pub bot_user: CurrentUser, + configs: DashMap, + __static_master_key: Option>, + pool: Pool, + pub chunk_requests: DashMap>, +} + +impl Context { + pub fn new( + cache: InMemoryCache, + cluster: Cluster, + http: HttpClient, + bot_user: CurrentUser, + pool: Pool, + static_key: Option>, + ) -> Self { + Context { + cache, + cluster, + http, + stats: BotStats::default(), + status_type: RwLock::new(3), + status_text: RwLock::new(String::from("the commands turn")), + bot_user, + configs: DashMap::new(), + pool, + chunk_requests: DashMap::new(), + __static_master_key: static_key, + } + } + + /// Returns if a message was sent by us. + /// + /// Returns None if we couldn't currently get a lock on the cache, but + /// rarely, if ever should this happen. + pub fn is_own(&self, other: &Message) -> bool { + self.bot_user.id == other.author.id + } +} + +mod cache; +mod database; +mod stats; diff --git a/src/core/context/stats.rs b/src/core/context/stats.rs new file mode 100644 index 00000000..71c2483f --- /dev/null +++ b/src/core/context/stats.rs @@ -0,0 +1,74 @@ +use crate::{core::Context, GIT_VERSION}; +use chrono::{DateTime, Utc}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use twilight::model::channel::Message; + +#[derive(Debug)] +pub struct BotStats { + pub start_time: DateTime, + pub user_messages: AtomicUsize, + pub bot_messages: AtomicUsize, + pub my_messages: AtomicUsize, + pub error_count: AtomicUsize, + pub commands_ran: AtomicUsize, + pub custom_commands_ran: AtomicUsize, + pub guilds: AtomicUsize, + pub version: &'static str, +} + +impl BotStats { + pub async fn new_message(&self, ctx: &Context, msg: &Message) { + if msg.author.bot { + // This will simply skip incrementing it if we couldn't get + // a lock on the cache. No harm done. + if ctx.is_own(msg) { + ctx.stats.my_messages.fetch_add(1, Ordering::Relaxed); + } + ctx.stats.bot_messages.fetch_add(1, Ordering::Relaxed); + } else { + ctx.stats.user_messages.fetch_add(1, Ordering::Relaxed); + } + } + + pub async fn had_error(&self) { + self.error_count.fetch_add(1, Ordering::Relaxed); + } + + pub async fn new_guild(&self) { + self.guilds.fetch_add(1, Ordering::Relaxed); + } + + pub async fn left_guild(&self) { + self.guilds.fetch_sub(1, Ordering::Relaxed); + } + + pub async fn command_used(&self, is_custom: bool) { + if !is_custom { + self.commands_ran.fetch_add(1, Ordering::Relaxed); + } else { + self.custom_commands_ran.fetch_add(1, Ordering::Relaxed); + } + } +} + +impl Default for BotStats { + fn default() -> Self { + BotStats { + start_time: Utc::now(), + user_messages: AtomicUsize::new(0), + bot_messages: AtomicUsize::new(0), + my_messages: AtomicUsize::new(0), + error_count: AtomicUsize::new(0), + commands_ran: AtomicUsize::new(0), + custom_commands_ran: AtomicUsize::new(0), + guilds: AtomicUsize::new(0), + version: GIT_VERSION, + } + } +} + +#[derive(Debug)] +pub struct LoadingState { + to_load: u32, + loaded: u32, +} diff --git a/src/core/gearbot.rs b/src/core/gearbot.rs index c14814e6..e23c724b 100644 --- a/src/core/gearbot.rs +++ b/src/core/gearbot.rs @@ -131,7 +131,7 @@ async fn handle_event(event: (u64, Event), ctx: Arc) -> Result<(), Erro _ => {} } - commands::handle_event(event.1, ctx.clone()).await?; + commands::handle_event(event.0, event.1, ctx.clone()).await?; Ok(()) } diff --git a/src/core/handlers/commands.rs b/src/core/handlers/commands.rs index 0d2389b4..4a059c8b 100644 --- a/src/core/handlers/commands.rs +++ b/src/core/handlers/commands.rs @@ -1,16 +1,16 @@ use std::sync::Arc; -use log::info; +use log::debug; use twilight::gateway::cluster::Event; use crate::core::Context; use crate::parser::Parser; use crate::utils::Error; -pub async fn handle_event<'a>(event: Event, ctx: Arc) -> Result<(), Error> { +pub async fn handle_event<'a>(shard_id: u64, event: Event, ctx: Arc) -> Result<(), Error> { match event { Event::MessageCreate(msg) if !msg.author.bot => { - info!( + debug!( "Received a message from {}, saying {}", msg.author.name, msg.content ); @@ -38,7 +38,7 @@ pub async fn handle_event<'a>(event: Event, ctx: Arc) -> Result<(), Err }; if let Some(prefix) = prefix { - Parser::figure_it_out(&prefix, msg, ctx).await?; + Parser::figure_it_out(&prefix, msg, ctx, shard_id).await?; } } _ => (), diff --git a/src/core/handlers/general.rs b/src/core/handlers/general.rs index 4fa96659..709d229a 100644 --- a/src/core/handlers/general.rs +++ b/src/core/handlers/general.rs @@ -10,6 +10,7 @@ use twilight::model::gateway::{ use crate::core::Context; use crate::utils::Error; use crate::{gearbot_info, gearbot_warn}; +// use futures::SinkExt; pub async fn handle_event(shard_id: u64, event: &Event, ctx: Arc) -> Result<(), Error> { match &event { @@ -56,6 +57,21 @@ pub async fn handle_event(shard_id: u64, event: &Event, ctx: Arc) -> Re .await?; } Event::Resumed => gearbot_info!("Shard {} successfully resumed", shard_id), + Event::MemberChunk(chunk) => { + debug!("got a chunk with nonce {:?}", &chunk.nonce); + match &chunk.nonce { + Some(nonce) => { + debug!("waiter found: {}", ctx.chunk_requests.contains_key(nonce)); + match ctx.chunk_requests.remove(nonce) { + Some(mut waiter) => { + waiter.1.send(chunk.clone()); + } + None => {} + } + } + None => {} + }; + } _ => (), } Ok(()) diff --git a/src/core/handlers/modlog.rs b/src/core/handlers/modlog.rs index 23952254..caf2ad49 100644 --- a/src/core/handlers/modlog.rs +++ b/src/core/handlers/modlog.rs @@ -18,7 +18,7 @@ pub async fn handle_event(shard_id: u64, event: &Event, ctx: Arc) -> Re if let Ok(handle) = res { match handle { - Ok(_) => return Ok(()), + Ok(_) => {} Err(e) => return Err(Error::TwilightCluster(e)), } } diff --git a/src/core/mod.rs b/src/core/mod.rs index 3338974a..d638a719 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,5 +1,6 @@ pub use bot_config::BotConfig; pub use context::Context; +pub use context::*; pub use gearbot::GearBot; pub use guild_config::GuildConfig; diff --git a/src/parser/arguments.rs b/src/parser/arguments.rs deleted file mode 100644 index 8b137891..00000000 --- a/src/parser/arguments.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/parser/mod.rs b/src/parser/mod.rs index fa8f5205..5590ba57 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -6,22 +6,32 @@ use twilight::model::gateway::payload::MessageCreate; use crate::commands; use crate::commands::meta::nodes::CommandNode; use crate::core::Context; -use crate::utils::Error; +use crate::utils::{matchers, Error, ParseError}; +use twilight::model::{ + id::{GuildId, UserId}, + user::User, +}; -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Parser { pub parts: Vec, - index: usize, + pub index: usize, + ctx: Arc, + shard_id: u64, + guild_id: Option, } impl Parser { - fn new(content: &str) -> Self { + fn new(content: &str, ctx: Arc, shard_id: u64, guild_id: Option) -> Self { Parser { parts: content .split_whitespace() .map(String::from) .collect::>(), index: 0, + ctx, + shard_id, + guild_id, } } @@ -38,6 +48,7 @@ impl Parser { to_search = node; debug!("Found a command node: {}", node.get_name()); self.index += 1; + debug!("{}", self.index); nodes.push(node); } None => { @@ -55,15 +66,20 @@ impl Parser { prefix: &str, message: Box, ctx: Arc, + shard_id: u64, ) -> Result<(), Error> { //TODO: verify permissions - let parser = Parser::new(&message.0.content[prefix.len()..]); + let mut parser = Parser::new( + &message.0.content[prefix.len()..], + ctx.clone(), + shard_id, + message.guild_id, + ); debug!("Parser processing message: {:?}", &message.content); - //TODO: walk the stack to validate permissions let mut p = parser.clone(); - let command_nodes = p.get_command(); + let command_nodes = parser.get_command(); match command_nodes.last() { Some(node) => { let mut name = String::new(); @@ -75,11 +91,46 @@ impl Parser { } debug!("Executing command: {}", name); - node.execute(ctx.clone(), message.0, parser).await?; + p.index += command_nodes.len(); + node.execute(ctx.clone(), message.0, p).await?; ctx.stats.command_used(false).await; + Ok(()) } None => Ok(()), } } + + pub fn get_next(&mut self) -> Result<&str, Error> { + if self.index == self.parts.len() { + Err(Error::ParseError(ParseError::MissingArgument)) + } else { + let result = &self.parts[self.index]; + self.index += 1; + debug!("{}", self.index); + Ok(result) + } + } + + /// parses what comes next as discord user + pub async fn get_user(&mut self) -> Result, Error> { + let input = self.get_next()?; + let mention = matchers::get_mention(input); + match mention { + // we got a mention + Some(uid) => Ok(self.ctx.get_user(UserId(uid)).await?), + None => { + // is it a userid? + match input.parse::() { + Ok(uid) => Ok(self.ctx.get_user(UserId(uid)).await?), + Err(_) => { + //nope, must be a partial name + Err(Error::ParseError(ParseError::MemberNotFoundByName( + "not implemented yet".to_string(), + ))) + } + } + } + } + } } diff --git a/src/utils/emoji.rs b/src/utils/emoji.rs index e7423995..c4ab43ba 100644 --- a/src/utils/emoji.rs +++ b/src/utils/emoji.rs @@ -11,7 +11,19 @@ define_emoji!( Yes => "✅", No => "đŸšĢ", Info => "ℹī¸", - Warn => "⚠ī¸" + Warn => "⚠ī¸", + Robot => "🤖", + + StaffBadge => "", + PartnerBadge => "", + HypesquadEvents => "", + BraveryBadge => "", + BrillianceBadge => "", + BalanceBadge => "", + BugHunterBadge => "", + EarlySupporterBadge => "", + BugHunterLvl2Badge => "", + VerifiedBotDevBadge => "" ); pub static EMOJI_OVERRIDES: OnceCell> = OnceCell::new(); diff --git a/src/utils/errors.rs b/src/utils/errors.rs index e976597e..caf9371c 100644 --- a/src/utils/errors.rs +++ b/src/utils/errors.rs @@ -1,6 +1,7 @@ use std::{error, fmt, io}; use deadpool_postgres::PoolError; +use serde::export::Formatter; use twilight::cache::twilight_cache_inmemory; use twilight::gateway::cluster; use twilight::http; @@ -26,11 +27,22 @@ pub enum Error { DatabaseMigration(String), UnknownEmoji(String), Serde(serde_json::error::Error), + ParseError(ParseError), } #[derive(Debug)] pub enum CommandError { - WrongArgCount { expected: u8, provided: u8 }, + // WrongArgCount { expected: u8, provided: u8 }, + NoDM, +} + +#[derive(Debug)] +pub enum ParseError { + MissingArgument, + MemberNotFoundById(u64), + MemberNotFoundByName(String), + MultipleMembersByName(String), + InvalidUserID(u64), } impl error::Error for CommandError {} @@ -38,21 +50,36 @@ impl error::Error for CommandError {} impl fmt::Display for CommandError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - CommandError::WrongArgCount { expected, provided } => { - if expected > provided { - write!( - f, - "Too many arguments were provided! Expected {}, but found {}", - expected, provided - ) - } else { - write!( - f, - "Not enough arguments were provided! Expected {}, but found {}", - expected, provided - ) - } - } + // CommandError::WrongArgCount { expected, provided } => { + // if expected > provided { + // write!( + // f, + // "Too many arguments were provided! Expected {}, but found {}", + // expected, provided + // ) + // } else { + // write!( + // f, + // "Not enough arguments were provided! Expected {}, but found {}", + // expected, provided + // ) + // } + // } + CommandError::NoDM => write!(f, "You can not use this command in DMs"), + } + } +} + +impl error::Error for ParseError {} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ParseError::MemberNotFoundById(id) =>write!(f, "no member with userid ``{}`` found on this server", id), + ParseError::MissingArgument => write!(f, "You are missing a requried argument"), + ParseError::MemberNotFoundByName(name) => write!(f, "There is nobody named ``{}`` on this server", name), + ParseError::MultipleMembersByName(name) => write!(f, "Multiple members who's name starts with ``{}`` found, please use their full name and discriminator", name), + ParseError::InvalidUserID(id) => write!(f, "``{}`` is not a valid discord userid", id) } } } @@ -103,6 +130,7 @@ impl fmt::Display for Error { Error::Pool(e) => write!(f, "An error occurred in the database pool: {}", e), Error::UnknownEmoji(e) => write!(f, "Unknown emoji: {}", e), Error::Serde(e) => write!(f, "Serde error: {}", e), + Error::ParseError(e) => write!(f, "{}", e), } } } diff --git a/src/utils/matchers.rs b/src/utils/matchers.rs index ed3102b4..7db540ad 100644 --- a/src/utils/matchers.rs +++ b/src/utils/matchers.rs @@ -1,4 +1,5 @@ use lazy_static::lazy_static; +use log::debug; use regex::{Match, Regex, RegexBuilder}; use url::{Host, Url}; @@ -27,6 +28,22 @@ pub fn contains_mention(msg: &str) -> bool { MENTION_MATCHER.is_match(msg) } +pub fn get_mention(msg: &str) -> Option { + debug!("{}", msg); + let captures = MENTION_MATCHER_SOLO.captures(msg); + debug!("{:?}", captures); + match captures { + Some(c) => Some( + c.get(1) + .map_or(Some(""), |m| Some(m.as_str())) + .unwrap() + .parse::() + .unwrap(), + ), + None => None, + } +} + pub fn contains_url(msg: &str) -> bool { URL_MATCHER.is_match(msg) } @@ -130,6 +147,10 @@ lazy_static! { static ref MENTION_MATCHER: Regex = { Regex::new(r"<@!?\d+>").unwrap() }; } +lazy_static! { + static ref MENTION_MATCHER_SOLO: Regex = { Regex::new(r"^<@!?(\d+)>$").unwrap() }; +} + lazy_static! { static ref URL_MATCHER: Regex = { RegexBuilder::new(r"((?:https?://)[a-z0-9]+(?:[-._][a-z0-9]+)*\.[a-z]{2,5}(?::[0-9]{1,5})?(?:/[^ \n<>]*)?)") diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a951d756..e1d7ef30 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,5 @@ +use chrono::{DateTime, NaiveDateTime, Utc}; + // Remove this when they are all used. #[allow(dead_code)] pub mod matchers; @@ -63,3 +65,83 @@ pub fn clean(msg: &str, markdown: bool, links: bool, emotes: bool, lookalikes: b msg } + +static DISCORD_EPOCH: i64 = 1420070400000; + +pub fn snowflake_timestamp(snowflake: u64) -> DateTime { + DateTime::from_utc( + NaiveDateTime::from_timestamp(((snowflake as i64 >> 22) + DISCORD_EPOCH) / 1000, 0), + Utc, + ) +} + +pub fn age(old: DateTime, new: DateTime, max_parts: i8) -> String { + let mut seconds = new.signed_duration_since(old).num_seconds(); + let mut parts = 0; + let mut output = "".to_string(); + + let years = (seconds as f64 / (60.0 * 60.0 * 24.0 * 365.25)) as i64; + if years > 0 { + seconds -= (years as f64 * 60.0 * 60.0 * 24.0 * 365.25) as i64; + output += &format!("{} years ", years); + parts += 1; + + if parts == max_parts { + return output; + } + } + + let months = seconds / (60 * 60 * 24 * 30); + if months > 0 { + seconds -= months * 60 * 60 * 24 * 30; + output += &format!("{} months ", months); + parts += 1; + + if parts == max_parts { + return output; + } + } + + let weeks = seconds / (60 * 60 * 24 * 7); + if weeks > 0 { + seconds -= weeks * 60 * 60 * 24 * 7; + output += &format!("{} weeks ", weeks); + parts += 1; + if parts == max_parts { + return output; + } + } + + let days = seconds / (60 * 60 * 24); + if days > 0 { + seconds -= days * 60 * 60 * 24; + output += &format!("{} days ", days); + parts += 1; + if parts == max_parts { + return output; + } + } + + let hours = seconds / (60 * 60); + if hours > 0 { + seconds -= hours * 60 * 60; + output += &format!("{} hours ", hours); + parts += 1; + if parts == max_parts { + return output; + } + } + + let minutes = seconds / 60; + if minutes > 0 { + seconds -= minutes * 60; + output += &format!("{} minutes ", minutes); + parts += 1; + if parts == max_parts { + return output; + } + } + + output += &format!("{} seconds", seconds); + output +}