Skip to content

Commit

Permalink
Introduce IntoProtocol trait for cleaning up error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
w4 committed Jan 30, 2024
1 parent 85756a6 commit 80c1733
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 94 deletions.
9 changes: 4 additions & 5 deletions src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{
events::{FetchAllUserChannelPermissions, SetUserChannelPermissions},
Persistence,
},
server::Server,
server::{response::IntoProtocol, Server},
};

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -441,7 +441,7 @@ impl Handler<ChannelJoin> for Channel {
}

// send the channel's topic to the joining user
for message in ChannelTopic::new(self).into_messages(self.name.to_string(), true) {
for message in ChannelTopic::new(self, true).into_messages(&self.name) {
msg.client.do_send(Broadcast {
message,
span: Span::current(),
Expand Down Expand Up @@ -497,8 +497,7 @@ impl Handler<ChannelUpdateTopic> for Channel {
});

for (client, connection) in &self.clients {
for message in ChannelTopic::new(self).into_messages(connection.nick.to_string(), false)
{
for message in ChannelTopic::new(self, false).into_messages(&connection.nick) {
client.do_send(Broadcast {
message,
span: Span::current(),
Expand Down Expand Up @@ -569,7 +568,7 @@ impl Handler<ChannelFetchTopic> for Channel {

#[instrument(parent = &msg.span, skip_all)]
fn handle(&mut self, msg: ChannelFetchTopic, _ctx: &mut Self::Context) -> Self::Result {
MessageResult(ChannelTopic::new(self))
MessageResult(ChannelTopic::new(self, msg.skip_on_none))
}
}

Expand Down
36 changes: 22 additions & 14 deletions src/channel/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,29 @@ use itertools::Itertools;
use crate::{
channel::{permissions::Permission, Channel, CurrentChannelTopic},
connection::InitiatedConnection,
server::response::IntoProtocol,
SERVER_NAME,
};

pub struct ChannelTopic {
pub channel_name: String,
pub topic: Option<CurrentChannelTopic>,
pub skip_on_none: bool,
}

impl ChannelTopic {
#[must_use]
pub fn new(channel: &Channel) -> Self {
pub fn new(channel: &Channel, skip_on_none: bool) -> Self {
Self {
channel_name: channel.name.to_string(),
topic: channel.topic.clone(),
skip_on_none,
}
}
}

#[must_use]
pub fn into_messages(self, for_user: String, skip_on_none: bool) -> Vec<Message> {
impl IntoProtocol for ChannelTopic {
fn into_messages(self, for_user: &str) -> Vec<Message> {
if let Some(topic) = self.topic {
vec![
Message {
Expand All @@ -43,21 +47,25 @@ impl ChannelTopic {
command: Command::Response(
Response::RPL_TOPICWHOTIME,
vec![
for_user,
for_user.to_string(),
self.channel_name.to_string(),
topic.set_by,
topic.set_time.timestamp().to_string(),
],
),
},
]
} else if !skip_on_none {
} else if !self.skip_on_none {
vec![Message {
tags: None,
prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
command: Command::Response(
Response::RPL_NOTOPIC,
vec![for_user, self.channel_name, "No topic is set".to_string()],
vec![
for_user.to_string(),
self.channel_name,
"No topic is set".to_string(),
],
),
}]
} else {
Expand All @@ -83,9 +91,10 @@ impl ChannelWhoList {
.collect(),
}
}
}

#[must_use]
pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for ChannelWhoList {
fn into_messages(self, for_user: &str) -> Vec<Message> {
let mut out = Vec::with_capacity(self.nick_list.len());

for (perm, conn) in self.nick_list {
Expand Down Expand Up @@ -233,18 +242,17 @@ pub enum ChannelJoinRejectionReason {
Banned,
}

impl ChannelJoinRejectionReason {
#[must_use]
pub fn into_message(self) -> Message {
impl IntoProtocol for ChannelJoinRejectionReason {
fn into_messages(self, for_user: &str) -> Vec<Message> {
match self {
Self::Banned => Message {
Self::Banned => vec![Message {
tags: None,
prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
command: Command::Response(
Response::ERR_BANNEDFROMCHAN,
vec!["Cannot join channel (+b)".to_string()],
vec![for_user.to_string(), "Cannot join channel (+b)".to_string()],
),
},
}],
}
}
}
Expand Down
64 changes: 22 additions & 42 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{
Persistence,
},
server::{
response::{NoSuchNick, WhoList},
response::{IntoProtocol, WhoList},
Server,
},
SERVER_NAME,
Expand Down Expand Up @@ -159,32 +159,27 @@ impl Client {
ctx: &mut Context<Self>,
channel: &Addr<Channel>,
message: M,
map: impl FnOnce(M::Result, &Self) -> Vec<Message> + 'static,
) where
M: actix::Message + Send + 'static,
M::Result: Send,
M::Result: Send + IntoProtocol,
Channel: Handler<M>,
<Channel as Actor>::Context: ToEnvelope<Channel, M>,
{
let fut = channel
.send(message)
.into_actor(self)
.map(move |result, ref mut this, _ctx| {
for message in (map)(result.unwrap(), this) {
for message in result.unwrap().into_messages(&this.connection.nick) {
this.writer.write(message);
}
});
ctx.spawn(fut);
}

fn server_send_map_write<M>(
&self,
ctx: &mut Context<Self>,
message: M,
map: impl FnOnce(M::Result, &Self) -> Vec<Message> + 'static,
) where
fn server_send_map_write<M>(&self, ctx: &mut Context<Self>, message: M)
where
M: actix::Message + Send + 'static,
M::Result: Send,
M::Result: Send + IntoProtocol,
Server: Handler<M>,
<Server as Actor>::Context: ToEnvelope<Server, M>,
{
Expand All @@ -193,7 +188,7 @@ impl Client {
.send(message)
.into_actor(self)
.map(move |result, ref mut this, _ctx| {
for message in (map)(result.unwrap(), this) {
for message in result.unwrap().into_messages(&this.connection.nick) {
this.writer.write(message);
}
});
Expand Down Expand Up @@ -304,7 +299,7 @@ impl Handler<ForceDisconnect> for Client {

fn handle(&mut self, _msg: ForceDisconnect, ctx: &mut Self::Context) -> Self::Result {
ctx.stop();
MessageResult(true)
MessageResult(Ok(()))
}
}

Expand Down Expand Up @@ -450,7 +445,9 @@ impl Handler<JoinChannelRequest> for Client {
Ok(v) => v,
Err(error) => {
error!(?error, "User failed to join channel");
this.writer.write(error.into_message());
for m in error.into_messages(&this.connection.nick) {
this.writer.write(m);
}
continue;
}
};
Expand Down Expand Up @@ -718,8 +715,10 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
self.channel_send_map_write(
ctx,
channel,
ChannelFetchTopic { span },
|res, this| res.into_messages(this.connection.nick.to_string(), false),
ChannelFetchTopic {
span,
skip_on_none: false,
},
);
}
}
Expand All @@ -740,9 +739,7 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
}
Command::LIST(_, _) => {
let span = Span::current();
self.server_send_map_write(ctx, ChannelList { span }, |res, this| {
res.into_messages(this.connection.nick.to_string())
});
self.server_send_map_write(ctx, ChannelList { span });
}
Command::INVITE(nick, channel) => {
let Some(channel) = self.channels.get(&channel) else {
Expand Down Expand Up @@ -800,15 +797,11 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
}
Command::MOTD(_) => {
let span = Span::current();
self.server_send_map_write(ctx, ServerFetchMotd { span }, |res, this| {
res.into_messages(this.connection.nick.to_string())
});
self.server_send_map_write(ctx, ServerFetchMotd { span });
}
Command::LUSERS(_, _) => {
let span = Span::current();
self.server_send_map_write(ctx, ServerListUsers { span }, |res, this| {
res.into_messages(&this.connection.nick)
});
self.server_send_map_write(ctx, ServerListUsers { span });
}
Command::VERSION(_) => {
self.writer.write(Message {
Expand Down Expand Up @@ -843,9 +836,7 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
}
Command::ADMIN(_) => {
let span = Span::current();
self.server_send_map_write(ctx, ServerAdminInfo { span }, |res, this| {
res.into_messages(&this.connection.nick)
});
self.server_send_map_write(ctx, ServerAdminInfo { span });
}
Command::INFO(_) => {
static INFO_STR: &str = include_str!("../text/info.txt");
Expand Down Expand Up @@ -874,15 +865,11 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
}
Command::WHO(Some(query), _) => {
let span = Span::current();
self.server_send_map_write(ctx, FetchWhoList { span, query }, |res, this| {
res.into_messages(&this.connection.nick)
});
self.server_send_map_write(ctx, FetchWhoList { span, query });
}
Command::WHOIS(Some(query), _) => {
let span = Span::current();
self.server_send_map_write(ctx, FetchWhois { span, query }, |res, this| {
res.into_messages(&this.connection.nick)
});
self.server_send_map_write(ctx, FetchWhois { span, query });
}
Command::WHOWAS(_, _, _) => {}
Command::KILL(nick, comment) => {
Expand Down Expand Up @@ -936,16 +923,9 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
ctx,
ForceDisconnect {
span,
user: user.to_string(),
user,
comment,
},
move |res, this| {
if res {
vec![]
} else {
NoSuchNick { nick: user }.into_messages(&this.connection.nick)
}
},
);
}
Command::AUTHENTICATE(_) => {
Expand Down
4 changes: 3 additions & 1 deletion src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
channel::Channel,
client::Client,
connection::{InitiatedConnection, UserId},
server::response::NoSuchNick,
};

/// Sent when a user is connecting to the server.
Expand Down Expand Up @@ -38,7 +39,7 @@ pub struct KillUser {
}

#[derive(Message, Clone)]
#[rtype(result = "bool")]
#[rtype(result = "Result<(), NoSuchNick>")]
pub struct ForceDisconnect {
pub span: Span,
pub user: String,
Expand Down Expand Up @@ -151,6 +152,7 @@ pub struct FetchUserPermission {
#[rtype(result = "super::channel::response::ChannelTopic")]
pub struct ChannelFetchTopic {
pub span: Span,
pub skip_on_none: bool,
}

/// Retrieves the WHO list for the channel.
Expand Down
9 changes: 5 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
UserNickChangeInternal, Wallops,
},
persistence::Persistence,
server::response::{AdminInfo, ListUsers, Motd, WhoList, Whois},
server::response::{AdminInfo, IntoProtocol, ListUsers, Motd, NoSuchNick, WhoList, Whois},
SERVER_NAME,
};

Expand Down Expand Up @@ -124,7 +124,7 @@ impl Handler<UserConnected> for Server {
});
}

for message in Motd::new(self).into_messages(msg.connection.nick.clone()) {
for message in Motd::new(self).into_messages(&msg.connection.nick) {
msg.handle.do_send(Broadcast {
span: Span::current(),
message,
Expand Down Expand Up @@ -313,9 +313,9 @@ impl Handler<ForceDisconnect> for Server {
fn handle(&mut self, msg: ForceDisconnect, _ctx: &mut Self::Context) -> Self::Result {
if let Some((handle, _)) = self.clients.iter().find(|(_, v)| v.nick == msg.user) {
handle.do_send(msg);
MessageResult(true)
MessageResult(Ok(()))
} else {
MessageResult(false)
MessageResult(Err(NoSuchNick { nick: msg.user }))
}
}
}
Expand Down Expand Up @@ -371,6 +371,7 @@ impl Handler<ChannelList> for Server {
.map(|channel| {
let fetch_topic = channel.send(ChannelFetchTopic {
span: Span::current(),
skip_on_none: true,
});

let fetch_members = channel.send(ChannelMemberList {
Expand Down
Loading

0 comments on commit 80c1733

Please sign in to comment.