Simplify error handling with anyhow

This commit is contained in:
John Doty 2022-10-09 08:21:03 -07:00
parent 82569743a3
commit 7f8e14384e
6 changed files with 90 additions and 138 deletions

View file

@ -1,3 +1,4 @@
use anyhow::{bail, Result};
use connection::ConnectionTable;
use message::{Message, MessageReader, MessageWriter};
use std::collections::HashMap;
@ -12,63 +13,6 @@ mod connection;
mod message;
mod refresh;
#[derive(Debug)]
pub enum Error {
Protocol,
ProtocolVersion,
IO(tokio::io::Error),
MessageIncomplete,
MessageUnknown,
MessageCorrupt,
ConnectionReset,
ProcFs(String),
NotSupported,
}
impl PartialEq for Error {
fn eq(&self, other: &Error) -> bool {
use Error::*;
match self {
Protocol => match other {
Protocol => true,
_ => false,
},
ProtocolVersion => match other {
ProtocolVersion => true,
_ => false,
},
IO(s) => match other {
IO(o) => s.kind() == o.kind(),
_ => false,
},
MessageIncomplete => match other {
MessageIncomplete => true,
_ => false,
},
MessageUnknown => match other {
MessageUnknown => true,
_ => false,
},
MessageCorrupt => match other {
MessageCorrupt => true,
_ => false,
},
ConnectionReset => match other {
ConnectionReset => true,
_ => false,
},
ProcFs(a) => match other {
ProcFs(b) => a == b,
_ => false,
},
NotSupported => match other {
NotSupported => true,
_ => false,
},
}
}
}
// ----------------------------------------------------------------------------
// Write Management
@ -92,7 +36,7 @@ impl PartialEq for Error {
async fn pump_write<T: AsyncWrite + Unpin>(
messages: &mut mpsc::Receiver<Message>,
writer: &mut MessageWriter<T>,
) -> Result<(), Error> {
) -> Result<()> {
while let Some(msg) = messages.recv().await {
writer.write(msg).await?;
}
@ -125,7 +69,7 @@ async fn server_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>,
connections: ConnectionTable,
) -> Result<(), Error> {
) -> Result<()> {
eprintln!("< Processing packets...");
loop {
let message = reader.read().await?;
@ -179,7 +123,7 @@ async fn server_read<T: AsyncRead + Unpin>(
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> {
) -> Result<()> {
let connections = ConnectionTable::new();
// The first message we send must be an announcement.
@ -217,7 +161,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
}
}
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> {
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<()> {
// TODO: While we're waiting here we should be echoing everything we read.
// We should also be proxying *our* stdin to the processes stdin,
// and turn that off when we've synchronized. That way we can
@ -225,11 +169,12 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
eprintln!("> Waiting for synchronization marker...");
let mut seen = 0;
while seen < 8 {
let byte = match reader.read_u8().await {
Ok(b) => b,
Err(e) => return Err(Error::IO(e)),
};
seen = if byte == 0 { seen + 1 } else { 0 };
let byte = reader.read_u8().await?;
if byte == 0 {
seen += 1;
} else {
tokio::io::stdout().write_u8(byte).await?;
}
}
Ok(())
}
@ -261,19 +206,13 @@ async fn client_listen(
port: u16,
writer: mpsc::Sender<Message>,
connections: ConnectionTable,
) -> Result<(), Error> {
) -> Result<()> {
loop {
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
Ok(t) => t,
Err(e) => return Err(Error::IO(e)),
};
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?;
loop {
// The second item contains the IP and port of the new
// connection, but we don't care.
let (mut socket, _) = match listener.accept().await {
Ok(s) => s,
Err(e) => return Err(Error::IO(e)),
};
let (mut socket, _) = listener.accept().await?;
let (writer, connections) = (writer.clone(), connections.clone());
tokio::spawn(async move {
@ -287,7 +226,7 @@ async fn client_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>,
connections: ConnectionTable,
) -> Result<(), Error> {
) -> Result<()> {
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
eprintln!("> Processing packets...");
@ -362,14 +301,14 @@ async fn client_read<T: AsyncRead + Unpin>(
async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> {
) -> Result<()> {
// Wait for the server's announcement.
if let Message::Hello(major, minor, _) = reader.read().await? {
if major != 0 || minor > 1 {
return Err(Error::ProtocolVersion);
bail!("Unsupported remote protocol version {}.{}", major, minor);
}
} else {
return Err(Error::Protocol);
bail!("Expected a hello message from the remote server");
}
// Kick things off with a listing of the ports...
@ -436,16 +375,13 @@ pub async fn run_server() {
}
}
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, Error> {
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, std::io::Error> {
let mut cmd = process::Command::new("ssh");
cmd.arg("-T").arg(server).arg("fwd").arg("--server");
cmd.stdout(std::process::Stdio::piped());
cmd.stdin(std::process::Stdio::piped());
match cmd.spawn() {
Ok(t) => Ok(t),
Err(e) => Err(Error::IO(e)),
}
cmd.spawn()
}
pub async fn run_client(remote: &str) {