Simplify error handling with anyhow
This commit is contained in:
parent
82569743a3
commit
7f8e14384e
6 changed files with 90 additions and 138 deletions
28
Cargo.lock
generated
28
Cargo.lock
generated
|
|
@ -17,6 +17,12 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anyhow"
|
||||||
|
version = "1.0.65"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
|
@ -124,8 +130,10 @@ dependencies = [
|
||||||
name = "fwd"
|
name = "fwd"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"bytes",
|
"bytes",
|
||||||
"procfs",
|
"procfs",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -398,6 +406,26 @@ dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.37"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.37"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "time"
|
name = "time"
|
||||||
version = "0.1.44"
|
version = "0.1.44"
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ edition = "2021"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
|
thiserror = "1.0"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
|
||||||
[target.'cfg(target_os="linux")'.dependencies]
|
[target.'cfg(target_os="linux")'.dependencies]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::message::Message;
|
use crate::message::Message;
|
||||||
use crate::Error;
|
use anyhow::Result;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
@ -30,11 +30,11 @@ async fn connection_read<T: AsyncRead + Unpin>(
|
||||||
channel: u64,
|
channel: u64,
|
||||||
read: &mut T,
|
read: &mut T,
|
||||||
writer: &mut mpsc::Sender<Message>,
|
writer: &mut mpsc::Sender<Message>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), tokio::io::Error> {
|
||||||
let result = loop {
|
let result = loop {
|
||||||
let mut buffer = BytesMut::with_capacity(MAX_PACKET);
|
let mut buffer = BytesMut::with_capacity(MAX_PACKET);
|
||||||
if let Err(e) = read.read_buf(&mut buffer).await {
|
if let Err(e) = read.read_buf(&mut buffer).await {
|
||||||
break Err(Error::IO(e));
|
break Err(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if buffer.len() == 0 {
|
if buffer.len() == 0 {
|
||||||
|
|
@ -42,7 +42,9 @@ async fn connection_read<T: AsyncRead + Unpin>(
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await {
|
if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await {
|
||||||
break Err(Error::ConnectionReset);
|
break Err(tokio::io::Error::from(
|
||||||
|
tokio::io::ErrorKind::ConnectionReset,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Flow control here, wait for the packet to be acknowleged so
|
// TODO: Flow control here, wait for the packet to be acknowleged so
|
||||||
|
|
@ -65,11 +67,9 @@ async fn connection_read<T: AsyncRead + Unpin>(
|
||||||
async fn connection_write<T: AsyncWrite + Unpin>(
|
async fn connection_write<T: AsyncWrite + Unpin>(
|
||||||
data: &mut mpsc::Receiver<Bytes>,
|
data: &mut mpsc::Receiver<Bytes>,
|
||||||
write: &mut T,
|
write: &mut T,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
while let Some(buf) = data.recv().await {
|
while let Some(buf) = data.recv().await {
|
||||||
if let Err(e) = write.write_all(&buf[..]).await {
|
write.write_all(&buf[..]).await?;
|
||||||
return Err(Error::IO(e));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
104
src/lib.rs
104
src/lib.rs
|
|
@ -1,3 +1,4 @@
|
||||||
|
use anyhow::{bail, Result};
|
||||||
use connection::ConnectionTable;
|
use connection::ConnectionTable;
|
||||||
use message::{Message, MessageReader, MessageWriter};
|
use message::{Message, MessageReader, MessageWriter};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -12,63 +13,6 @@ mod connection;
|
||||||
mod message;
|
mod message;
|
||||||
mod refresh;
|
mod refresh;
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum Error {
|
|
||||||
Protocol,
|
|
||||||
ProtocolVersion,
|
|
||||||
IO(tokio::io::Error),
|
|
||||||
MessageIncomplete,
|
|
||||||
MessageUnknown,
|
|
||||||
MessageCorrupt,
|
|
||||||
ConnectionReset,
|
|
||||||
ProcFs(String),
|
|
||||||
NotSupported,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for Error {
|
|
||||||
fn eq(&self, other: &Error) -> bool {
|
|
||||||
use Error::*;
|
|
||||||
match self {
|
|
||||||
Protocol => match other {
|
|
||||||
Protocol => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
ProtocolVersion => match other {
|
|
||||||
ProtocolVersion => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
IO(s) => match other {
|
|
||||||
IO(o) => s.kind() == o.kind(),
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
MessageIncomplete => match other {
|
|
||||||
MessageIncomplete => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
MessageUnknown => match other {
|
|
||||||
MessageUnknown => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
MessageCorrupt => match other {
|
|
||||||
MessageCorrupt => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
ConnectionReset => match other {
|
|
||||||
ConnectionReset => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
ProcFs(a) => match other {
|
|
||||||
ProcFs(b) => a == b,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
NotSupported => match other {
|
|
||||||
NotSupported => true,
|
|
||||||
_ => false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Write Management
|
// Write Management
|
||||||
|
|
||||||
|
|
@ -92,7 +36,7 @@ impl PartialEq for Error {
|
||||||
async fn pump_write<T: AsyncWrite + Unpin>(
|
async fn pump_write<T: AsyncWrite + Unpin>(
|
||||||
messages: &mut mpsc::Receiver<Message>,
|
messages: &mut mpsc::Receiver<Message>,
|
||||||
writer: &mut MessageWriter<T>,
|
writer: &mut MessageWriter<T>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
while let Some(msg) = messages.recv().await {
|
while let Some(msg) = messages.recv().await {
|
||||||
writer.write(msg).await?;
|
writer.write(msg).await?;
|
||||||
}
|
}
|
||||||
|
|
@ -125,7 +69,7 @@ async fn server_read<T: AsyncRead + Unpin>(
|
||||||
reader: &mut MessageReader<T>,
|
reader: &mut MessageReader<T>,
|
||||||
writer: mpsc::Sender<Message>,
|
writer: mpsc::Sender<Message>,
|
||||||
connections: ConnectionTable,
|
connections: ConnectionTable,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
eprintln!("< Processing packets...");
|
eprintln!("< Processing packets...");
|
||||||
loop {
|
loop {
|
||||||
let message = reader.read().await?;
|
let message = reader.read().await?;
|
||||||
|
|
@ -179,7 +123,7 @@ async fn server_read<T: AsyncRead + Unpin>(
|
||||||
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
reader: &mut MessageReader<Reader>,
|
reader: &mut MessageReader<Reader>,
|
||||||
writer: &mut MessageWriter<Writer>,
|
writer: &mut MessageWriter<Writer>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
let connections = ConnectionTable::new();
|
let connections = ConnectionTable::new();
|
||||||
|
|
||||||
// The first message we send must be an announcement.
|
// The first message we send must be an announcement.
|
||||||
|
|
@ -217,7 +161,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> {
|
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<()> {
|
||||||
// TODO: While we're waiting here we should be echoing everything we read.
|
// TODO: While we're waiting here we should be echoing everything we read.
|
||||||
// We should also be proxying *our* stdin to the processes stdin,
|
// We should also be proxying *our* stdin to the processes stdin,
|
||||||
// and turn that off when we've synchronized. That way we can
|
// and turn that off when we've synchronized. That way we can
|
||||||
|
|
@ -225,11 +169,12 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
|
||||||
eprintln!("> Waiting for synchronization marker...");
|
eprintln!("> Waiting for synchronization marker...");
|
||||||
let mut seen = 0;
|
let mut seen = 0;
|
||||||
while seen < 8 {
|
while seen < 8 {
|
||||||
let byte = match reader.read_u8().await {
|
let byte = reader.read_u8().await?;
|
||||||
Ok(b) => b,
|
if byte == 0 {
|
||||||
Err(e) => return Err(Error::IO(e)),
|
seen += 1;
|
||||||
};
|
} else {
|
||||||
seen = if byte == 0 { seen + 1 } else { 0 };
|
tokio::io::stdout().write_u8(byte).await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -261,19 +206,13 @@ async fn client_listen(
|
||||||
port: u16,
|
port: u16,
|
||||||
writer: mpsc::Sender<Message>,
|
writer: mpsc::Sender<Message>,
|
||||||
connections: ConnectionTable,
|
connections: ConnectionTable,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
loop {
|
loop {
|
||||||
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
|
let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?;
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => return Err(Error::IO(e)),
|
|
||||||
};
|
|
||||||
loop {
|
loop {
|
||||||
// The second item contains the IP and port of the new
|
// The second item contains the IP and port of the new
|
||||||
// connection, but we don't care.
|
// connection, but we don't care.
|
||||||
let (mut socket, _) = match listener.accept().await {
|
let (mut socket, _) = listener.accept().await?;
|
||||||
Ok(s) => s,
|
|
||||||
Err(e) => return Err(Error::IO(e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (writer, connections) = (writer.clone(), connections.clone());
|
let (writer, connections) = (writer.clone(), connections.clone());
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|
@ -287,7 +226,7 @@ async fn client_read<T: AsyncRead + Unpin>(
|
||||||
reader: &mut MessageReader<T>,
|
reader: &mut MessageReader<T>,
|
||||||
writer: mpsc::Sender<Message>,
|
writer: mpsc::Sender<Message>,
|
||||||
connections: ConnectionTable,
|
connections: ConnectionTable,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
|
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
|
||||||
|
|
||||||
eprintln!("> Processing packets...");
|
eprintln!("> Processing packets...");
|
||||||
|
|
@ -362,14 +301,14 @@ async fn client_read<T: AsyncRead + Unpin>(
|
||||||
async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
reader: &mut MessageReader<Reader>,
|
reader: &mut MessageReader<Reader>,
|
||||||
writer: &mut MessageWriter<Writer>,
|
writer: &mut MessageWriter<Writer>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<()> {
|
||||||
// Wait for the server's announcement.
|
// Wait for the server's announcement.
|
||||||
if let Message::Hello(major, minor, _) = reader.read().await? {
|
if let Message::Hello(major, minor, _) = reader.read().await? {
|
||||||
if major != 0 || minor > 1 {
|
if major != 0 || minor > 1 {
|
||||||
return Err(Error::ProtocolVersion);
|
bail!("Unsupported remote protocol version {}.{}", major, minor);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::Protocol);
|
bail!("Expected a hello message from the remote server");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Kick things off with a listing of the ports...
|
// Kick things off with a listing of the ports...
|
||||||
|
|
@ -436,16 +375,13 @@ pub async fn run_server() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, Error> {
|
async fn spawn_ssh(server: &str) -> Result<tokio::process::Child, std::io::Error> {
|
||||||
let mut cmd = process::Command::new("ssh");
|
let mut cmd = process::Command::new("ssh");
|
||||||
cmd.arg("-T").arg(server).arg("fwd").arg("--server");
|
cmd.arg("-T").arg(server).arg("fwd").arg("--server");
|
||||||
|
|
||||||
cmd.stdout(std::process::Stdio::piped());
|
cmd.stdout(std::process::Stdio::piped());
|
||||||
cmd.stdin(std::process::Stdio::piped());
|
cmd.stdin(std::process::Stdio::piped());
|
||||||
match cmd.spawn() {
|
cmd.spawn()
|
||||||
Ok(t) => Ok(t),
|
|
||||||
Err(e) => Err(Error::IO(e)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_client(remote: &str) {
|
pub async fn run_client(remote: &str) {
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,20 @@
|
||||||
use crate::Error;
|
use anyhow::Result;
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
use thiserror::Error;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Messages
|
// Messages
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum MessageError {
|
||||||
|
#[error("Message type unknown: {0}")]
|
||||||
|
Unknown(u8),
|
||||||
|
#[error("Message incomplete")]
|
||||||
|
Incomplete,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
pub struct PortDesc {
|
pub struct PortDesc {
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
|
@ -83,7 +92,7 @@ impl Message {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message, Error> {
|
pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message> {
|
||||||
use Message::*;
|
use Message::*;
|
||||||
match get_u8(cursor)? {
|
match get_u8(cursor)? {
|
||||||
0x00 => Ok(Ping),
|
0x00 => Ok(Ping),
|
||||||
|
|
@ -127,48 +136,45 @@ impl Message {
|
||||||
let data = get_bytes(cursor, length.into())?;
|
let data = get_bytes(cursor, length.into())?;
|
||||||
Ok(Data(channel, data))
|
Ok(Data(channel, data))
|
||||||
}
|
}
|
||||||
_ => Err(Error::MessageUnknown),
|
b => Err(MessageError::Unknown(b).into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, Error> {
|
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> {
|
||||||
if !cursor.has_remaining() {
|
if !cursor.has_remaining() {
|
||||||
return Err(Error::MessageIncomplete);
|
return Err(MessageError::Incomplete);
|
||||||
}
|
}
|
||||||
Ok(cursor.get_u8())
|
Ok(cursor.get_u8())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, Error> {
|
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> {
|
||||||
if cursor.remaining() < 2 {
|
if cursor.remaining() < 2 {
|
||||||
return Err(Error::MessageIncomplete);
|
return Err(MessageError::Incomplete);
|
||||||
}
|
}
|
||||||
Ok(cursor.get_u16())
|
Ok(cursor.get_u16())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, Error> {
|
fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, MessageError> {
|
||||||
if cursor.remaining() < 8 {
|
if cursor.remaining() < 8 {
|
||||||
return Err(Error::MessageIncomplete);
|
return Err(MessageError::Incomplete);
|
||||||
}
|
}
|
||||||
Ok(cursor.get_u64())
|
Ok(cursor.get_u64())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, Error> {
|
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> {
|
||||||
if cursor.remaining() < length {
|
if cursor.remaining() < length {
|
||||||
return Err(Error::MessageIncomplete);
|
return Err(MessageError::Incomplete);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(cursor.copy_to_bytes(length))
|
Ok(cursor.copy_to_bytes(length))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String, Error> {
|
fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String> {
|
||||||
let length = get_u16(cursor)?;
|
let length = get_u16(cursor)?;
|
||||||
|
|
||||||
let data = get_bytes(cursor, length.into())?;
|
let data = get_bytes(cursor, length.into())?;
|
||||||
match std::str::from_utf8(&data[..]) {
|
Ok(std::str::from_utf8(&data[..])?.to_owned())
|
||||||
Ok(s) => Ok(s.to_owned()),
|
|
||||||
Err(_) => return Err(Error::MessageCorrupt),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn slice_up_to(s: &str, max_len: usize) -> &str {
|
fn slice_up_to(s: &str, max_len: usize) -> &str {
|
||||||
|
|
@ -198,13 +204,7 @@ impl<T: AsyncWrite + Unpin> MessageWriter<T> {
|
||||||
pub fn new(writer: T) -> MessageWriter<T> {
|
pub fn new(writer: T) -> MessageWriter<T> {
|
||||||
MessageWriter { writer }
|
MessageWriter { writer }
|
||||||
}
|
}
|
||||||
pub async fn write(self: &mut Self, msg: Message) -> Result<(), Error> {
|
pub async fn write(self: &mut Self, msg: Message) -> Result<()> {
|
||||||
match self.write_impl(msg).await {
|
|
||||||
Err(e) => Err(Error::IO(e)),
|
|
||||||
Ok(ok) => Ok(ok),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async fn write_impl(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> {
|
|
||||||
// TODO: Optimize buffer usage please this is bad
|
// TODO: Optimize buffer usage please this is bad
|
||||||
// eprintln!("? {:?}", msg);
|
// eprintln!("? {:?}", msg);
|
||||||
let mut buffer = msg.encode();
|
let mut buffer = msg.encode();
|
||||||
|
|
@ -225,16 +225,10 @@ impl<T: AsyncRead + Unpin> MessageReader<T> {
|
||||||
pub fn new(reader: T) -> MessageReader<T> {
|
pub fn new(reader: T) -> MessageReader<T> {
|
||||||
MessageReader { reader }
|
MessageReader { reader }
|
||||||
}
|
}
|
||||||
pub async fn read(self: &mut Self) -> Result<Message, Error> {
|
pub async fn read(self: &mut Self) -> Result<Message> {
|
||||||
let frame_length = match self.reader.read_u32().await {
|
let frame_length = self.reader.read_u32().await?;
|
||||||
Ok(l) => l,
|
|
||||||
Err(e) => return Err(Error::IO(e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap());
|
let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap());
|
||||||
if let Err(e) = self.reader.read_buf(&mut data).await {
|
self.reader.read_buf(&mut data).await?;
|
||||||
return Err(Error::IO(e));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut cursor = Cursor::new(&data[..]);
|
let mut cursor = Cursor::new(&data[..]);
|
||||||
Message::decode(&mut cursor)
|
Message::decode(&mut cursor)
|
||||||
|
|
@ -250,8 +244,8 @@ mod message_tests {
|
||||||
fn assert_round_trip(message: Message) {
|
fn assert_round_trip(message: Message) {
|
||||||
let encoded = message.encode();
|
let encoded = message.encode();
|
||||||
let mut cursor = std::io::Cursor::new(&encoded[..]);
|
let mut cursor = std::io::Cursor::new(&encoded[..]);
|
||||||
let result = Message::decode(&mut cursor);
|
let result = Message::decode(&mut cursor).unwrap();
|
||||||
assert_eq!(Ok(message.clone()), result);
|
assert_eq!(message.clone(), result);
|
||||||
|
|
||||||
let rt = tokio::runtime::Builder::new_current_thread()
|
let rt = tokio::runtime::Builder::new_current_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,13 @@
|
||||||
use crate::message::PortDesc;
|
use crate::message::PortDesc;
|
||||||
use crate::Error;
|
use anyhow::{bail, Result};
|
||||||
|
|
||||||
#[cfg(not(target_os = "linux"))]
|
#[cfg(not(target_os = "linux"))]
|
||||||
pub fn get_entries() -> Result<Vec<PortDesc>, Error> {
|
pub fn get_entries() -> Result<Vec<PortDesc>> {
|
||||||
Err(Error::NotSupported)
|
bail!("Not supported on this operating system");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
pub fn get_entries() -> Result<Vec<PortDesc>, Error> {
|
pub fn get_entries() -> Result<Vec<PortDesc>> {
|
||||||
match get_entries_linux() {
|
|
||||||
Ok(v) => Ok(v),
|
|
||||||
Err(e) => Err(Error::ProcFs(format!("{:?}", e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
pub fn get_entries_linux() -> procfs::ProcResult<Vec<PortDesc>> {
|
|
||||||
use procfs::process::FDTarget;
|
use procfs::process::FDTarget;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue