Simplify error handling with anyhow
This commit is contained in:
parent
82569743a3
commit
7f8e14384e
6 changed files with 90 additions and 138 deletions
104
src/lib.rs
104
src/lib.rs
|
|
@ -1,3 +1,4 @@
|
|||
use anyhow::{bail, Result};
|
||||
use connection::ConnectionTable;
|
||||
use message::{Message, MessageReader, MessageWriter};
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -12,63 +13,6 @@ mod connection;
|
|||
mod message;
|
||||
mod refresh;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Protocol,
|
||||
ProtocolVersion,
|
||||
IO(tokio::io::Error),
|
||||
MessageIncomplete,
|
||||
MessageUnknown,
|
||||
MessageCorrupt,
|
||||
ConnectionReset,
|
||||
ProcFs(String),
|
||||
NotSupported,
|
||||
}
|
||||
|
||||
impl PartialEq for Error {
|
||||
fn eq(&self, other: &Error) -> bool {
|
||||
use Error::*;
|
||||
match self {
|
||||
Protocol => match other {
|
||||
Protocol => true,
|
||||
_ => false,
|
||||
},
|
||||
ProtocolVersion => match other {
|
||||
ProtocolVersion => true,
|
||||
_ => false,
|
||||
},
|
||||
IO(s) => match other {
|
||||
IO(o) => s.kind() == o.kind(),
|
||||
_ => false,
|
||||
},
|
||||
MessageIncomplete => match other {
|
||||
MessageIncomplete => true,
|
||||
_ => false,
|
||||
},
|
||||
MessageUnknown => match other {
|
||||
MessageUnknown => true,
|
||||
_ => false,
|
||||
},
|
||||
MessageCorrupt => match other {
|
||||
MessageCorrupt => true,
|
||||
_ => false,
|
||||
},
|
||||
ConnectionReset => match other {
|
||||
ConnectionReset => true,
|
||||
_ => false,
|
||||
},
|
||||
ProcFs(a) => match other {
|
||||
ProcFs(b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
NotSupported => match other {
|
||||
NotSupported => true,
|
||||
_ => false,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Write Management
|
||||
|
||||
|
|
@ -92,7 +36,7 @@ impl PartialEq for Error {
|
|||
async fn pump_write<T: AsyncWrite + Unpin>(
|
||||
messages: &mut mpsc::Receiver<Message>,
|
||||
writer: &mut MessageWriter<T>,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
while let Some(msg) = messages.recv().await {
|
||||
writer.write(msg).await?;
|
||||
}
|
||||
|
|
@ -125,7 +69,7 @@ async fn server_read<T: AsyncRead + Unpin>(
|
|||
reader: &mut MessageReader<T>,
|
||||
writer: mpsc::Sender<Message>,
|
||||
connections: ConnectionTable,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
eprintln!("< Processing packets...");
|
||||
loop {
|
||||
let message = reader.read().await?;
|
||||
|
|
@ -179,7 +123,7 @@ async fn server_read<T: AsyncRead + Unpin>(
|
|||
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||
reader: &mut MessageReader<Reader>,
|
||||
writer: &mut MessageWriter<Writer>,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
let connections = ConnectionTable::new();
|
||||
|
||||
// The first message we send must be an announcement.
|
||||
|
|
@ -217,7 +161,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
|||
}
|
||||
}
|
||||
|
||||
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> {
|
||||
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<()> {
|
||||
// TODO: While we're waiting here we should be echoing everything we read.
|
||||
// We should also be proxying *our* stdin to the processes stdin,
|
||||
// and turn that off when we've synchronized. That way we can
|
||||
|
|
@ -225,11 +169,12 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
|
|||
eprintln!("> Waiting for synchronization marker...");
|
||||
let mut seen = 0;
|
||||
while seen < 8 {
|
||||
let byte = match reader.read_u8().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => return Err(Error::IO(e)),
|
||||
};
|
||||
seen = if byte == 0 { seen + 1 } else { 0 };
|
||||
let byte = reader.read_u8().await?;
|
||||
if byte == 0 {
|
||||
seen += 1;
|
||||
} else {
|
||||
tokio::io::stdout().write_u8(byte).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -261,19 +206,13 @@ async fn client_listen(
|
|||
port: u16,
|
||||
writer: mpsc::Sender<Message>,
|
||||
connections: ConnectionTable,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => return Err(Error::IO(e)),
|
||||
};
|
||||
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?;
|
||||
loop {
|
||||
// The second item contains the IP and port of the new
|
||||
// connection, but we don't care.
|
||||
let (mut socket, _) = match listener.accept().await {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Err(Error::IO(e)),
|
||||
};
|
||||
let (mut socket, _) = listener.accept().await?;
|
||||
|
||||
let (writer, connections) = (writer.clone(), connections.clone());
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -287,7 +226,7 @@ async fn client_read<T: AsyncRead + Unpin>(
|
|||
reader: &mut MessageReader<T>,
|
||||
writer: mpsc::Sender<Message>,
|
||||
connections: ConnectionTable,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
|
||||
|
||||
eprintln!("> Processing packets...");
|
||||
|
|
@ -362,14 +301,14 @@ async fn client_read<T: AsyncRead + Unpin>(
|
|||
async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||
reader: &mut MessageReader<Reader>,
|
||||
writer: &mut MessageWriter<Writer>,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<()> {
|
||||
// Wait for the server's announcement.
|
||||
if let Message::Hello(major, minor, _) = reader.read().await? {
|
||||
if major != 0 || minor > 1 {
|
||||
return Err(Error::ProtocolVersion);
|
||||
bail!("Unsupported remote protocol version {}.{}", major, minor);
|
||||
}
|
||||
} else {
|
||||
return Err(Error::Protocol);
|
||||
bail!("Expected a hello message from the remote server");
|
||||
}
|
||||
|
||||
// Kick things off with a listing of the ports...
|
||||
|
|
@ -436,16 +375,13 @@ pub async fn run_server() {
|
|||
}
|
||||
}
|
||||
|
||||
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, Error> {
|
||||
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, std::io::Error> {
|
||||
let mut cmd = process::Command::new("ssh");
|
||||
cmd.arg("-T").arg(server).arg("fwd").arg("--server");
|
||||
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
match cmd.spawn() {
|
||||
Ok(t) => Ok(t),
|
||||
Err(e) => Err(Error::IO(e)),
|
||||
}
|
||||
cmd.spawn()
|
||||
}
|
||||
|
||||
pub async fn run_client(remote: &str) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue