Simplify error handling with anyhow

This commit is contained in:
John Doty 2022-10-09 08:21:03 -07:00
parent 82569743a3
commit 7f8e14384e
6 changed files with 90 additions and 138 deletions

28
Cargo.lock generated
View file

@ -17,6 +17,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anyhow"
version = "1.0.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.1.0" version = "1.1.0"
@ -124,8 +130,10 @@ dependencies = [
name = "fwd" name = "fwd"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"bytes", "bytes",
"procfs", "procfs",
"thiserror",
"tokio", "tokio",
] ]
@ -398,6 +406,26 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "time" name = "time"
version = "0.1.44" version = "0.1.44"

View file

@ -6,7 +6,9 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
anyhow = "1.0"
bytes = "1" bytes = "1"
thiserror = "1.0"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
[target.'cfg(target_os="linux")'.dependencies] [target.'cfg(target_os="linux")'.dependencies]

View file

@ -1,5 +1,5 @@
use crate::message::Message; use crate::message::Message;
use crate::Error; use anyhow::Result;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@ -30,11 +30,11 @@ async fn connection_read<T: AsyncRead + Unpin>(
channel: u64, channel: u64,
read: &mut T, read: &mut T,
writer: &mut mpsc::Sender<Message>, writer: &mut mpsc::Sender<Message>,
) -> Result<(), Error> { ) -> Result<(), tokio::io::Error> {
let result = loop { let result = loop {
let mut buffer = BytesMut::with_capacity(MAX_PACKET); let mut buffer = BytesMut::with_capacity(MAX_PACKET);
if let Err(e) = read.read_buf(&mut buffer).await { if let Err(e) = read.read_buf(&mut buffer).await {
break Err(Error::IO(e)); break Err(e);
} }
if buffer.len() == 0 { if buffer.len() == 0 {
@ -42,7 +42,9 @@ async fn connection_read<T: AsyncRead + Unpin>(
} }
if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await {
break Err(Error::ConnectionReset); break Err(tokio::io::Error::from(
tokio::io::ErrorKind::ConnectionReset,
));
} }
// TODO: Flow control here, wait for the packet to be acknowleged so // TODO: Flow control here, wait for the packet to be acknowleged so
@ -65,11 +67,9 @@ async fn connection_read<T: AsyncRead + Unpin>(
async fn connection_write<T: AsyncWrite + Unpin>( async fn connection_write<T: AsyncWrite + Unpin>(
data: &mut mpsc::Receiver<Bytes>, data: &mut mpsc::Receiver<Bytes>,
write: &mut T, write: &mut T,
) -> Result<(), Error> { ) -> Result<()> {
while let Some(buf) = data.recv().await { while let Some(buf) = data.recv().await {
if let Err(e) = write.write_all(&buf[..]).await { write.write_all(&buf[..]).await?;
return Err(Error::IO(e));
}
} }
Ok(()) Ok(())
} }

View file

@ -1,3 +1,4 @@
use anyhow::{bail, Result};
use connection::ConnectionTable; use connection::ConnectionTable;
use message::{Message, MessageReader, MessageWriter}; use message::{Message, MessageReader, MessageWriter};
use std::collections::HashMap; use std::collections::HashMap;
@ -12,63 +13,6 @@ mod connection;
mod message; mod message;
mod refresh; mod refresh;
#[derive(Debug)]
pub enum Error {
Protocol,
ProtocolVersion,
IO(tokio::io::Error),
MessageIncomplete,
MessageUnknown,
MessageCorrupt,
ConnectionReset,
ProcFs(String),
NotSupported,
}
impl PartialEq for Error {
fn eq(&self, other: &Error) -> bool {
use Error::*;
match self {
Protocol => match other {
Protocol => true,
_ => false,
},
ProtocolVersion => match other {
ProtocolVersion => true,
_ => false,
},
IO(s) => match other {
IO(o) => s.kind() == o.kind(),
_ => false,
},
MessageIncomplete => match other {
MessageIncomplete => true,
_ => false,
},
MessageUnknown => match other {
MessageUnknown => true,
_ => false,
},
MessageCorrupt => match other {
MessageCorrupt => true,
_ => false,
},
ConnectionReset => match other {
ConnectionReset => true,
_ => false,
},
ProcFs(a) => match other {
ProcFs(b) => a == b,
_ => false,
},
NotSupported => match other {
NotSupported => true,
_ => false,
},
}
}
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Write Management // Write Management
@ -92,7 +36,7 @@ impl PartialEq for Error {
async fn pump_write<T: AsyncWrite + Unpin>( async fn pump_write<T: AsyncWrite + Unpin>(
messages: &mut mpsc::Receiver<Message>, messages: &mut mpsc::Receiver<Message>,
writer: &mut MessageWriter<T>, writer: &mut MessageWriter<T>,
) -> Result<(), Error> { ) -> Result<()> {
while let Some(msg) = messages.recv().await { while let Some(msg) = messages.recv().await {
writer.write(msg).await?; writer.write(msg).await?;
} }
@ -125,7 +69,7 @@ async fn server_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>, reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<()> {
eprintln!("< Processing packets..."); eprintln!("< Processing packets...");
loop { loop {
let message = reader.read().await?; let message = reader.read().await?;
@ -179,7 +123,7 @@ async fn server_read<T: AsyncRead + Unpin>(
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>( async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>, reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>, writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> { ) -> Result<()> {
let connections = ConnectionTable::new(); let connections = ConnectionTable::new();
// The first message we send must be an announcement. // The first message we send must be an announcement.
@ -217,7 +161,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
} }
} }
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> { async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<()> {
// TODO: While we're waiting here we should be echoing everything we read. // TODO: While we're waiting here we should be echoing everything we read.
// We should also be proxying *our* stdin to the processes stdin, // We should also be proxying *our* stdin to the processes stdin,
// and turn that off when we've synchronized. That way we can // and turn that off when we've synchronized. That way we can
@ -225,11 +169,12 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
eprintln!("> Waiting for synchronization marker..."); eprintln!("> Waiting for synchronization marker...");
let mut seen = 0; let mut seen = 0;
while seen < 8 { while seen < 8 {
let byte = match reader.read_u8().await { let byte = reader.read_u8().await?;
Ok(b) => b, if byte == 0 {
Err(e) => return Err(Error::IO(e)), seen += 1;
}; } else {
seen = if byte == 0 { seen + 1 } else { 0 }; tokio::io::stdout().write_u8(byte).await?;
}
} }
Ok(()) Ok(())
} }
@ -261,19 +206,13 @@ async fn client_listen(
port: u16, port: u16,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<()> {
loop { loop {
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?;
Ok(t) => t,
Err(e) => return Err(Error::IO(e)),
};
loop { loop {
// The second item contains the IP and port of the new // The second item contains the IP and port of the new
// connection, but we don't care. // connection, but we don't care.
let (mut socket, _) = match listener.accept().await { let (mut socket, _) = listener.accept().await?;
Ok(s) => s,
Err(e) => return Err(Error::IO(e)),
};
let (writer, connections) = (writer.clone(), connections.clone()); let (writer, connections) = (writer.clone(), connections.clone());
tokio::spawn(async move { tokio::spawn(async move {
@ -287,7 +226,7 @@ async fn client_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>, reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<()> {
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new(); let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
eprintln!("> Processing packets..."); eprintln!("> Processing packets...");
@ -362,14 +301,14 @@ async fn client_read<T: AsyncRead + Unpin>(
async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>( async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>, reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>, writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> { ) -> Result<()> {
// Wait for the server's announcement. // Wait for the server's announcement.
if let Message::Hello(major, minor, _) = reader.read().await? { if let Message::Hello(major, minor, _) = reader.read().await? {
if major != 0 || minor > 1 { if major != 0 || minor > 1 {
return Err(Error::ProtocolVersion); bail!("Unsupported remote protocol version {}.{}", major, minor);
} }
} else { } else {
return Err(Error::Protocol); bail!("Expected a hello message from the remote server");
} }
// Kick things off with a listing of the ports... // Kick things off with a listing of the ports...
@ -436,16 +375,13 @@ pub async fn run_server() {
} }
} }
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, Error> { async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, std::io::Error> {
let mut cmd = process::Command::new("ssh"); let mut cmd = process::Command::new("ssh");
cmd.arg("-T").arg(server).arg("fwd").arg("--server"); cmd.arg("-T").arg(server).arg("fwd").arg("--server");
cmd.stdout(std::process::Stdio::piped()); cmd.stdout(std::process::Stdio::piped());
cmd.stdin(std::process::Stdio::piped()); cmd.stdin(std::process::Stdio::piped());
match cmd.spawn() { cmd.spawn()
Ok(t) => Ok(t),
Err(e) => Err(Error::IO(e)),
}
} }
pub async fn run_client(remote: &str) { pub async fn run_client(remote: &str) {

View file

@ -1,11 +1,20 @@
use crate::Error; use anyhow::Result;
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io::Cursor; use std::io::Cursor;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Messages // Messages
#[derive(Debug, Error)]
pub enum MessageError {
#[error("Message type unknown: {0}")]
Unknown(u8),
#[error("Message incomplete")]
Incomplete,
}
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct PortDesc { pub struct PortDesc {
pub port: u16, pub port: u16,
@ -83,7 +92,7 @@ impl Message {
}; };
} }
pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message, Error> { pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message> {
use Message::*; use Message::*;
match get_u8(cursor)? { match get_u8(cursor)? {
0x00 => Ok(Ping), 0x00 => Ok(Ping),
@ -127,48 +136,45 @@ impl Message {
let data = get_bytes(cursor, length.into())?; let data = get_bytes(cursor, length.into())?;
Ok(Data(channel, data)) Ok(Data(channel, data))
} }
_ => Err(Error::MessageUnknown), b => Err(MessageError::Unknown(b).into()),
} }
} }
} }
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, Error> { fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> {
if !cursor.has_remaining() { if !cursor.has_remaining() {
return Err(Error::MessageIncomplete); return Err(MessageError::Incomplete);
} }
Ok(cursor.get_u8()) Ok(cursor.get_u8())
} }
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, Error> { fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> {
if cursor.remaining() < 2 { if cursor.remaining() < 2 {
return Err(Error::MessageIncomplete); return Err(MessageError::Incomplete);
} }
Ok(cursor.get_u16()) Ok(cursor.get_u16())
} }
fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, Error> { fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, MessageError> {
if cursor.remaining() < 8 { if cursor.remaining() < 8 {
return Err(Error::MessageIncomplete); return Err(MessageError::Incomplete);
} }
Ok(cursor.get_u64()) Ok(cursor.get_u64())
} }
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, Error> { fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> {
if cursor.remaining() < length { if cursor.remaining() < length {
return Err(Error::MessageIncomplete); return Err(MessageError::Incomplete);
} }
Ok(cursor.copy_to_bytes(length)) Ok(cursor.copy_to_bytes(length))
} }
fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String, Error> { fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String> {
let length = get_u16(cursor)?; let length = get_u16(cursor)?;
let data = get_bytes(cursor, length.into())?; let data = get_bytes(cursor, length.into())?;
match std::str::from_utf8(&data[..]) { Ok(std::str::from_utf8(&data[..])?.to_owned())
Ok(s) => Ok(s.to_owned()),
Err(_) => return Err(Error::MessageCorrupt),
}
} }
fn slice_up_to(s: &str, max_len: usize) -> &str { fn slice_up_to(s: &str, max_len: usize) -> &str {
@ -198,13 +204,7 @@ impl<T: AsyncWrite + Unpin> MessageWriter<T> {
pub fn new(writer: T) -> MessageWriter<T> { pub fn new(writer: T) -> MessageWriter<T> {
MessageWriter { writer } MessageWriter { writer }
} }
pub async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { pub async fn write(self: &mut Self, msg: Message) -> Result<()> {
match self.write_impl(msg).await {
Err(e) => Err(Error::IO(e)),
Ok(ok) => Ok(ok),
}
}
async fn write_impl(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> {
// TODO: Optimize buffer usage please this is bad // TODO: Optimize buffer usage please this is bad
// eprintln!("? {:?}", msg); // eprintln!("? {:?}", msg);
let mut buffer = msg.encode(); let mut buffer = msg.encode();
@ -225,16 +225,10 @@ impl<T: AsyncRead + Unpin> MessageReader<T> {
pub fn new(reader: T) -> MessageReader<T> { pub fn new(reader: T) -> MessageReader<T> {
MessageReader { reader } MessageReader { reader }
} }
pub async fn read(self: &mut Self) -> Result<Message, Error> { pub async fn read(self: &mut Self) -> Result<Message> {
let frame_length = match self.reader.read_u32().await { let frame_length = self.reader.read_u32().await?;
Ok(l) => l,
Err(e) => return Err(Error::IO(e)),
};
let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap());
if let Err(e) = self.reader.read_buf(&mut data).await { self.reader.read_buf(&mut data).await?;
return Err(Error::IO(e));
}
let mut cursor = Cursor::new(&data[..]); let mut cursor = Cursor::new(&data[..]);
Message::decode(&mut cursor) Message::decode(&mut cursor)
@ -250,8 +244,8 @@ mod message_tests {
fn assert_round_trip(message: Message) { fn assert_round_trip(message: Message) {
let encoded = message.encode(); let encoded = message.encode();
let mut cursor = std::io::Cursor::new(&encoded[..]); let mut cursor = std::io::Cursor::new(&encoded[..]);
let result = Message::decode(&mut cursor); let result = Message::decode(&mut cursor).unwrap();
assert_eq!(Ok(message.clone()), result); assert_eq!(message.clone(), result);
let rt = tokio::runtime::Builder::new_current_thread() let rt = tokio::runtime::Builder::new_current_thread()
.enable_all() .enable_all()

View file

@ -1,21 +1,13 @@
use crate::message::PortDesc; use crate::message::PortDesc;
use crate::Error; use anyhow::{bail, Result};
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
pub fn get_entries() -> Result<Vec<PortDesc>, Error> { pub fn get_entries() -> Result<Vec<PortDesc>> {
Err(Error::NotSupported) bail!("Not supported on this operating system");
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
pub fn get_entries() -> Result<Vec<PortDesc>, Error> { pub fn get_entries() -> Result<Vec<PortDesc>> {
match get_entries_linux() {
Ok(v) => Ok(v),
Err(e) => Err(Error::ProcFs(format!("{:?}", e))),
}
}
#[cfg(target_os = "linux")]
pub fn get_entries_linux() -> procfs::ProcResult<Vec<PortDesc>> {
use procfs::process::FDTarget; use procfs::process::FDTarget;
use std::collections::HashMap; use std::collections::HashMap;