Simplify error handling with anyhow
This commit is contained in:
parent
82569743a3
commit
7f8e14384e
6 changed files with 90 additions and 138 deletions
|
|
@ -1,11 +1,20 @@
|
|||
use crate::Error;
|
||||
use anyhow::Result;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Messages
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MessageError {
|
||||
#[error("Message type unknown: {0}")]
|
||||
Unknown(u8),
|
||||
#[error("Message incomplete")]
|
||||
Incomplete,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct PortDesc {
|
||||
pub port: u16,
|
||||
|
|
@ -83,7 +92,7 @@ impl Message {
|
|||
};
|
||||
}
|
||||
|
||||
pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message, Error> {
|
||||
pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result<Message> {
|
||||
use Message::*;
|
||||
match get_u8(cursor)? {
|
||||
0x00 => Ok(Ping),
|
||||
|
|
@ -127,48 +136,45 @@ impl Message {
|
|||
let data = get_bytes(cursor, length.into())?;
|
||||
Ok(Data(channel, data))
|
||||
}
|
||||
_ => Err(Error::MessageUnknown),
|
||||
b => Err(MessageError::Unknown(b).into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, Error> {
|
||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> {
|
||||
if !cursor.has_remaining() {
|
||||
return Err(Error::MessageIncomplete);
|
||||
return Err(MessageError::Incomplete);
|
||||
}
|
||||
Ok(cursor.get_u8())
|
||||
}
|
||||
|
||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, Error> {
|
||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> {
|
||||
if cursor.remaining() < 2 {
|
||||
return Err(Error::MessageIncomplete);
|
||||
return Err(MessageError::Incomplete);
|
||||
}
|
||||
Ok(cursor.get_u16())
|
||||
}
|
||||
|
||||
fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, Error> {
|
||||
fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result<u64, MessageError> {
|
||||
if cursor.remaining() < 8 {
|
||||
return Err(Error::MessageIncomplete);
|
||||
return Err(MessageError::Incomplete);
|
||||
}
|
||||
Ok(cursor.get_u64())
|
||||
}
|
||||
|
||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, Error> {
|
||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> {
|
||||
if cursor.remaining() < length {
|
||||
return Err(Error::MessageIncomplete);
|
||||
return Err(MessageError::Incomplete);
|
||||
}
|
||||
|
||||
Ok(cursor.copy_to_bytes(length))
|
||||
}
|
||||
|
||||
fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String, Error> {
|
||||
fn get_string(cursor: &mut Cursor<&[u8]>) -> Result<String> {
|
||||
let length = get_u16(cursor)?;
|
||||
|
||||
let data = get_bytes(cursor, length.into())?;
|
||||
match std::str::from_utf8(&data[..]) {
|
||||
Ok(s) => Ok(s.to_owned()),
|
||||
Err(_) => return Err(Error::MessageCorrupt),
|
||||
}
|
||||
Ok(std::str::from_utf8(&data[..])?.to_owned())
|
||||
}
|
||||
|
||||
fn slice_up_to(s: &str, max_len: usize) -> &str {
|
||||
|
|
@ -198,13 +204,7 @@ impl<T: AsyncWrite + Unpin> MessageWriter<T> {
|
|||
pub fn new(writer: T) -> MessageWriter<T> {
|
||||
MessageWriter { writer }
|
||||
}
|
||||
pub async fn write(self: &mut Self, msg: Message) -> Result<(), Error> {
|
||||
match self.write_impl(msg).await {
|
||||
Err(e) => Err(Error::IO(e)),
|
||||
Ok(ok) => Ok(ok),
|
||||
}
|
||||
}
|
||||
async fn write_impl(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> {
|
||||
pub async fn write(self: &mut Self, msg: Message) -> Result<()> {
|
||||
// TODO: Optimize buffer usage please this is bad
|
||||
// eprintln!("? {:?}", msg);
|
||||
let mut buffer = msg.encode();
|
||||
|
|
@ -225,16 +225,10 @@ impl<T: AsyncRead + Unpin> MessageReader<T> {
|
|||
pub fn new(reader: T) -> MessageReader<T> {
|
||||
MessageReader { reader }
|
||||
}
|
||||
pub async fn read(self: &mut Self) -> Result<Message, Error> {
|
||||
let frame_length = match self.reader.read_u32().await {
|
||||
Ok(l) => l,
|
||||
Err(e) => return Err(Error::IO(e)),
|
||||
};
|
||||
|
||||
pub async fn read(self: &mut Self) -> Result<Message> {
|
||||
let frame_length = self.reader.read_u32().await?;
|
||||
let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap());
|
||||
if let Err(e) = self.reader.read_buf(&mut data).await {
|
||||
return Err(Error::IO(e));
|
||||
}
|
||||
self.reader.read_buf(&mut data).await?;
|
||||
|
||||
let mut cursor = Cursor::new(&data[..]);
|
||||
Message::decode(&mut cursor)
|
||||
|
|
@ -250,8 +244,8 @@ mod message_tests {
|
|||
fn assert_round_trip(message: Message) {
|
||||
let encoded = message.encode();
|
||||
let mut cursor = std::io::Cursor::new(&encoded[..]);
|
||||
let result = Message::decode(&mut cursor);
|
||||
assert_eq!(Ok(message.clone()), result);
|
||||
let result = Message::decode(&mut cursor).unwrap();
|
||||
assert_eq!(message.clone(), result);
|
||||
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue