diff --git a/Cargo.lock b/Cargo.lock index 3b37c5c..ccac33e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" + [[package]] name = "autocfg" version = "1.1.0" @@ -124,8 +130,10 @@ dependencies = [ name = "fwd" version = "0.1.0" dependencies = [ + "anyhow", "bytes", "procfs", + "thiserror", "tokio", ] @@ -398,6 +406,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "time" version = "0.1.44" diff --git a/Cargo.toml b/Cargo.toml index 24c07f4..f45d97f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow = "1.0" bytes = "1" +thiserror = "1.0" tokio = { version = "1", features = ["full"] } [target.'cfg(target_os="linux")'.dependencies] diff --git a/src/connection.rs b/src/connection.rs index 9bf697d..360ea75 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,5 +1,5 @@ use crate::message::Message; -use crate::Error; +use anyhow::Result; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -30,11 +30,11 @@ async fn connection_read( channel: u64, read: &mut T, writer: &mut mpsc::Sender, -) -> Result<(), Error> { +) -> Result<(), tokio::io::Error> { let result = loop { let mut buffer = BytesMut::with_capacity(MAX_PACKET); if let Err(e) = read.read_buf(&mut buffer).await { - break Err(Error::IO(e)); + break Err(e); } if buffer.len() == 0 { @@ -42,7 +42,9 @@ async fn connection_read( } if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { - break Err(Error::ConnectionReset); + break Err(tokio::io::Error::from( + tokio::io::ErrorKind::ConnectionReset, + )); } // TODO: Flow control here, wait for the packet to be acknowleged so @@ -65,11 +67,9 @@ async fn connection_read( async fn connection_write( data: &mut mpsc::Receiver, write: &mut T, -) -> Result<(), Error> { +) -> Result<()> { while let Some(buf) = data.recv().await { - if let Err(e) = write.write_all(&buf[..]).await { - return Err(Error::IO(e)); - } + write.write_all(&buf[..]).await?; } Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index aaa61df..b17eeab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +use anyhow::{bail, Result}; use connection::ConnectionTable; use message::{Message, MessageReader, MessageWriter}; use std::collections::HashMap; @@ -12,63 +13,6 @@ mod connection; mod message; mod refresh; -#[derive(Debug)] -pub enum Error { - Protocol, - ProtocolVersion, - IO(tokio::io::Error), - MessageIncomplete, - MessageUnknown, - MessageCorrupt, - ConnectionReset, - ProcFs(String), - NotSupported, -} - -impl PartialEq for Error { - fn eq(&self, other: &Error) -> bool { - use Error::*; - match self { - Protocol => match other { - Protocol => true, - _ => false, - }, - ProtocolVersion => match other { - ProtocolVersion => true, - _ => false, - }, - IO(s) => match other { - IO(o) => s.kind() == o.kind(), - _ => false, - }, - MessageIncomplete => match other { - MessageIncomplete => true, - _ => false, - }, - MessageUnknown => match other { - MessageUnknown => true, - _ => false, - }, - MessageCorrupt => match other { - MessageCorrupt => true, - _ => false, - }, - ConnectionReset => match other { - ConnectionReset => true, - _ => false, - }, - ProcFs(a) => match other { - ProcFs(b) => a == b, - _ => false, - }, - NotSupported => match other { - NotSupported => true, - _ => false, - }, - } - } -} - // ---------------------------------------------------------------------------- // Write Management @@ -92,7 +36,7 @@ impl PartialEq for Error { async fn pump_write( messages: &mut mpsc::Receiver, writer: &mut MessageWriter, -) -> Result<(), Error> { +) -> Result<()> { while let Some(msg) = messages.recv().await { writer.write(msg).await?; } @@ -125,7 +69,7 @@ async fn server_read( reader: &mut MessageReader, writer: mpsc::Sender, connections: ConnectionTable, -) -> Result<(), Error> { +) -> Result<()> { eprintln!("< Processing packets..."); loop { let message = reader.read().await?; @@ -179,7 +123,7 @@ async fn server_read( async fn server_main( reader: &mut MessageReader, writer: &mut MessageWriter, -) -> Result<(), Error> { +) -> Result<()> { let connections = ConnectionTable::new(); // The first message we send must be an announcement. @@ -217,7 +161,7 @@ async fn server_main( } } -async fn client_sync(reader: &mut T) -> Result<(), Error> { +async fn client_sync(reader: &mut T) -> Result<()> { // TODO: While we're waiting here we should be echoing everything we read. // We should also be proxying *our* stdin to the processes stdin, // and turn that off when we've synchronized. That way we can @@ -225,11 +169,12 @@ async fn client_sync(reader: &mut T) -> Result<(), Error> eprintln!("> Waiting for synchronization marker..."); let mut seen = 0; while seen < 8 { - let byte = match reader.read_u8().await { - Ok(b) => b, - Err(e) => return Err(Error::IO(e)), - }; - seen = if byte == 0 { seen + 1 } else { 0 }; + let byte = reader.read_u8().await?; + if byte == 0 { + seen += 1; + } else { + tokio::io::stdout().write_u8(byte).await?; + } } Ok(()) } @@ -261,19 +206,13 @@ async fn client_listen( port: u16, writer: mpsc::Sender, connections: ConnectionTable, -) -> Result<(), Error> { +) -> Result<()> { loop { - let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { - Ok(t) => t, - Err(e) => return Err(Error::IO(e)), - }; + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; loop { // The second item contains the IP and port of the new // connection, but we don't care. - let (mut socket, _) = match listener.accept().await { - Ok(s) => s, - Err(e) => return Err(Error::IO(e)), - }; + let (mut socket, _) = listener.accept().await?; let (writer, connections) = (writer.clone(), connections.clone()); tokio::spawn(async move { @@ -287,7 +226,7 @@ async fn client_read( reader: &mut MessageReader, writer: mpsc::Sender, connections: ConnectionTable, -) -> Result<(), Error> { +) -> Result<()> { let mut listeners: HashMap> = HashMap::new(); eprintln!("> Processing packets..."); @@ -362,14 +301,14 @@ async fn client_read( async fn client_main( reader: &mut MessageReader, writer: &mut MessageWriter, -) -> Result<(), Error> { +) -> Result<()> { // Wait for the server's announcement. if let Message::Hello(major, minor, _) = reader.read().await? { if major != 0 || minor > 1 { - return Err(Error::ProtocolVersion); + bail!("Unsupported remote protocol version {}.{}", major, minor); } } else { - return Err(Error::Protocol); + bail!("Expected a hello message from the remote server"); } // Kick things off with a listing of the ports... @@ -436,16 +375,13 @@ pub async fn run_server() { } } -async fn spawn_ssh(server: &str) -> Result { +async fn spawn_ssh(server: &str) -> Result { let mut cmd = process::Command::new("ssh"); cmd.arg("-T").arg(server).arg("fwd").arg("--server"); cmd.stdout(std::process::Stdio::piped()); cmd.stdin(std::process::Stdio::piped()); - match cmd.spawn() { - Ok(t) => Ok(t), - Err(e) => Err(Error::IO(e)), - } + cmd.spawn() } pub async fn run_client(remote: &str) { diff --git a/src/message.rs b/src/message.rs index aaa8a9d..0f66a3f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,11 +1,20 @@ -use crate::Error; +use anyhow::Result; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::io::Cursor; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // ---------------------------------------------------------------------------- // Messages +#[derive(Debug, Error)] +pub enum MessageError { + #[error("Message type unknown: {0}")] + Unknown(u8), + #[error("Message incomplete")] + Incomplete, +} + #[derive(Debug, PartialEq, Clone)] pub struct PortDesc { pub port: u16, @@ -83,7 +92,7 @@ impl Message { }; } - pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result { + pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result { use Message::*; match get_u8(cursor)? { 0x00 => Ok(Ping), @@ -127,48 +136,45 @@ impl Message { let data = get_bytes(cursor, length.into())?; Ok(Data(channel, data)) } - _ => Err(Error::MessageUnknown), + b => Err(MessageError::Unknown(b).into()), } } } -fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { if !cursor.has_remaining() { - return Err(Error::MessageIncomplete); + return Err(MessageError::Incomplete); } Ok(cursor.get_u8()) } -fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { if cursor.remaining() < 2 { - return Err(Error::MessageIncomplete); + return Err(MessageError::Incomplete); } Ok(cursor.get_u16()) } -fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { if cursor.remaining() < 8 { - return Err(Error::MessageIncomplete); + return Err(MessageError::Incomplete); } Ok(cursor.get_u64()) } -fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { +fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { if cursor.remaining() < length { - return Err(Error::MessageIncomplete); + return Err(MessageError::Incomplete); } Ok(cursor.copy_to_bytes(length)) } -fn get_string(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_string(cursor: &mut Cursor<&[u8]>) -> Result { let length = get_u16(cursor)?; let data = get_bytes(cursor, length.into())?; - match std::str::from_utf8(&data[..]) { - Ok(s) => Ok(s.to_owned()), - Err(_) => return Err(Error::MessageCorrupt), - } + Ok(std::str::from_utf8(&data[..])?.to_owned()) } fn slice_up_to(s: &str, max_len: usize) -> &str { @@ -198,13 +204,7 @@ impl MessageWriter { pub fn new(writer: T) -> MessageWriter { MessageWriter { writer } } - pub async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { - match self.write_impl(msg).await { - Err(e) => Err(Error::IO(e)), - Ok(ok) => Ok(ok), - } - } - async fn write_impl(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> { + pub async fn write(self: &mut Self, msg: Message) -> Result<()> { // TODO: Optimize buffer usage please this is bad // eprintln!("? {:?}", msg); let mut buffer = msg.encode(); @@ -225,16 +225,10 @@ impl MessageReader { pub fn new(reader: T) -> MessageReader { MessageReader { reader } } - pub async fn read(self: &mut Self) -> Result { - let frame_length = match self.reader.read_u32().await { - Ok(l) => l, - Err(e) => return Err(Error::IO(e)), - }; - + pub async fn read(self: &mut Self) -> Result { + let frame_length = self.reader.read_u32().await?; let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); - if let Err(e) = self.reader.read_buf(&mut data).await { - return Err(Error::IO(e)); - } + self.reader.read_buf(&mut data).await?; let mut cursor = Cursor::new(&data[..]); Message::decode(&mut cursor) @@ -250,8 +244,8 @@ mod message_tests { fn assert_round_trip(message: Message) { let encoded = message.encode(); let mut cursor = std::io::Cursor::new(&encoded[..]); - let result = Message::decode(&mut cursor); - assert_eq!(Ok(message.clone()), result); + let result = Message::decode(&mut cursor).unwrap(); + assert_eq!(message.clone(), result); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/src/refresh.rs b/src/refresh.rs index ebff676..3f7275d 100644 --- a/src/refresh.rs +++ b/src/refresh.rs @@ -1,21 +1,13 @@ use crate::message::PortDesc; -use crate::Error; +use anyhow::{bail, Result}; #[cfg(not(target_os = "linux"))] -pub fn get_entries() -> Result, Error> { - Err(Error::NotSupported) +pub fn get_entries() -> Result> { + bail!("Not supported on this operating system"); } #[cfg(target_os = "linux")] -pub fn get_entries() -> Result, Error> { - match get_entries_linux() { - Ok(v) => Ok(v), - Err(e) => Err(Error::ProcFs(format!("{:?}", e))), - } -} - -#[cfg(target_os = "linux")] -pub fn get_entries_linux() -> procfs::ProcResult> { +pub fn get_entries() -> Result> { use procfs::process::FDTarget; use std::collections::HashMap;