diff --git a/src/lib.rs b/src/lib.rs index 1e7a3b6..8ea6a0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,43 +1,28 @@ use bytes::{Bytes, BytesMut}; use std::collections::HashMap; -use std::io::Cursor; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::{Arc, Mutex}; -use tokio::io::{ - AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, Error, ErrorKind, -}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::process; use tokio::sync::mpsc; use tokio::sync::oneshot; +mod error; mod message; mod refresh; -use message::Message; +use message::{Message, MessageReader, MessageWriter}; -// ---------------------------------------------------------------------------- -// Message Writing - -struct MessageWriter { - writer: T, -} - -impl MessageWriter { - fn new(writer: T) -> MessageWriter { - MessageWriter { writer } - } - async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { - // TODO: Optimize buffer usage please this is bad - // eprintln!("? {:?}", msg); - let mut buffer = msg.encode(); - self.writer - .write_u32(buffer.len().try_into().expect("Message too large")) - .await?; - self.writer.write_buf(&mut buffer).await?; - self.writer.flush().await?; - Ok(()) - } +#[derive(Debug)] +pub enum Error { + Protocol, + ProtocolVersion, + IO(tokio::io::Error), + MessageIncomplete, + MessageUnknown, + MessageCorrupt, + ConnectionReset, } async fn pump_write( @@ -63,14 +48,14 @@ async fn connection_read( let result = loop { let mut buffer = BytesMut::with_capacity(64 * 1024); if let Err(e) = read.read_buf(&mut buffer).await { - break Err(e); + break Err(Error::IO(e)); } if buffer.len() == 0 { break Ok(()); } if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { - break Err(Error::from(ErrorKind::ConnectionReset)); + break Err(Error::ConnectionReset); } // TODO: Flow control here, wait for the packet to be acknowleged so @@ -91,7 +76,9 @@ async fn connection_write( write: &mut T, ) -> Result<(), Error> { while let Some(buf) = data.recv().await { - write.write_all(&buf[..]).await?; + if let Err(e) = write.write_all(&buf[..]).await { + return Err(Error::IO(e)); + } } Ok(()) } @@ -183,28 +170,16 @@ async fn server_handle_connection( eprintln!("< Done server!"); } } - - // Wrong! - _ = writer.send(Message::Closed(channel)); } async fn server_read( - reader: &mut T, + reader: &mut MessageReader, writer: mpsc::Sender, connections: ServerConnectionTable, ) -> Result<(), Error> { eprintln!("< Processing packets..."); loop { - let frame_length = reader.read_u32().await?; - - let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); - reader.read_buf(&mut data).await?; - - let mut cursor = Cursor::new(&data[..]); - let message = match Message::decode(&mut cursor) { - Ok(msg) => msg, - Err(_) => return Err(Error::from(ErrorKind::InvalidData)), - }; + let message = reader.read().await?; use Message::*; match message { @@ -253,11 +228,14 @@ async fn server_read( } async fn server_main( - reader: &mut Reader, + reader: &mut MessageReader, writer: &mut MessageWriter, ) -> Result<(), Error> { let connections = ServerConnectionTable::new(); + // The first message we send must be an announcement. + writer.write(Message::Hello(0, 1, vec![])).await?; + // Jump into it... let (msg_sender, mut msg_receiver) = mpsc::channel(32); let writing = pump_write(&mut msg_receiver, writer); @@ -296,14 +274,20 @@ async fn spawn_ssh(server: &str) -> Result { cmd.stdout(std::process::Stdio::piped()); cmd.stdin(std::process::Stdio::piped()); - cmd.spawn() + match cmd.spawn() { + Ok(t) => Ok(t), + Err(e) => Err(Error::IO(e)), + } } async fn client_sync(reader: &mut T) -> Result<(), Error> { eprintln!("> Waiting for synchronization marker..."); let mut seen = 0; while seen < 8 { - let byte = reader.read_u8().await?; + let byte = match reader.read_u8().await { + Ok(b) => b, + Err(e) => return Err(Error::IO(e)), + }; seen = if byte == 0 { seen + 1 } else { 0 }; } Ok(()) @@ -413,11 +397,17 @@ async fn client_listen( connections: ClientConnectionTable, ) -> Result<(), Error> { loop { - let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; + let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { + Ok(t) => t, + Err(e) => return Err(Error::IO(e)), + }; loop { // The second item contains the IP and port of the new // connection, but we don't care. - let (mut socket, _) = listener.accept().await?; + let (mut socket, _) = match listener.accept().await { + Ok(s) => s, + Err(e) => return Err(Error::IO(e)), + }; let (writer, connections) = (writer.clone(), connections.clone()); tokio::spawn(async move { @@ -428,7 +418,7 @@ async fn client_listen( } async fn client_read( - reader: &mut T, + reader: &mut MessageReader, writer: mpsc::Sender, connections: ClientConnectionTable, ) -> Result<(), Error> { @@ -436,16 +426,7 @@ async fn client_read( eprintln!("> Processing packets..."); loop { - let frame_length = reader.read_u32().await?; - - let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); - reader.read_buf(&mut data).await?; - - let mut cursor = Cursor::new(&data[..]); - let message = match Message::decode(&mut cursor) { - Ok(msg) => msg, - Err(_) => return Err(Error::from(ErrorKind::InvalidData)), - }; + let message = reader.read().await?; use Message::*; match message { @@ -513,14 +494,19 @@ async fn client_read( } async fn client_main( - reader: &mut Reader, + reader: &mut MessageReader, writer: &mut MessageWriter, ) -> Result<(), Error> { - // First synchronize; we're looking for the 8-zero marker that is the 64b sync marker. - // This helps us skip garbage like any kind of MOTD or whatnot. - client_sync(reader).await?; + // Wait for the server's announcement. + if let Message::Hello(major, minor, _) = reader.read().await? { + if major != 0 || minor > 1 { + return Err(Error::ProtocolVersion); + } + } else { + return Err(Error::Protocol); + } - // Now kick things off with a listing of the ports... + // Kick things off with a listing of the ports... eprintln!("> Sending initial list command..."); writer.write(Message::Refresh).await?; @@ -561,21 +547,24 @@ async fn client_main( ///// pub async fn run_server() { - let mut reader = BufReader::new(tokio::io::stdin()); + let reader = BufReader::new(tokio::io::stdin()); let mut writer = BufWriter::new(tokio::io::stdout()); - // Write the marker. + // Write the 8-byte synchronization marker. eprintln!("< Writing marker..."); writer .write_u64(0x00_00_00_00_00_00_00_00) .await .expect("Error writing marker"); - writer.flush().await.expect("Error flushing buffer"); + if let Err(e) = writer.flush().await { + eprintln!("Error writing sync marker: {:?}", e); + return; + } eprintln!("< Done!"); let mut writer = MessageWriter::new(writer); - + let mut reader = MessageReader::new(reader); if let Err(e) = server_main(&mut reader, &mut writer).await { eprintln!("Error: {:?}", e); } @@ -599,6 +588,12 @@ pub async fn run_client(remote: &str) { .expect("child did not have a handle to stdout"), ); + if let Err(e) = client_sync(&mut reader).await { + eprintln!("Error synchronizing: {:?}", e); + return; + } + + let mut reader = MessageReader::new(reader); if let Err(e) = client_main(&mut reader, &mut writer).await { eprintln!("Error: {:?}", e); } diff --git a/src/message.rs b/src/message.rs index 95be271..4d0d8fc 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,12 +1,7 @@ +use crate::Error; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::io::Cursor; - -#[derive(Debug, PartialEq)] -pub enum MessageError { - Incomplete, - UnknownMessage, - Corrupt, -} +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Debug, PartialEq, Clone)] pub struct PortDesc { @@ -16,122 +11,120 @@ pub struct PortDesc { #[derive(Debug, PartialEq)] pub enum Message { - Ping, - Connect(u64, u16), // Request to connect on a port from client to server. - Connected(u64), // Sucessfully connected from server to client. - Close(u64), // Request to close connection on either end. - // Abort(u64), // Notify of close from server to client. - Closed(u64), // Response to Close or Abort. - Refresh, // Request to refresh list of ports from client. - Ports(Vec), // List of available ports from server to client. - Data(u64, Bytes), // Transmit data. + Ping, // Ignored on both sides, can be used to test connection. + Hello(u8, u8, Vec), // Server info announcement: major version, minor version, headers. + Connect(u64, u16), // Request to connect on a port from client to server. + Connected(u64), // Sucessfully connected from server to client. + Close(u64), // Notify that one or the other end of a channel is closed. + Refresh, // Request to refresh list of ports from client. + Ports(Vec), // List of available ports from server to client. + Data(u64, Bytes), // Transmit data on a channel. } impl Message { pub fn encode(self: &Message) -> BytesMut { - use Message::*; let mut result = BytesMut::new(); + self.encode_buf(&mut result); + result + } + + pub fn encode_buf(self: &Message, result: &mut T) { + use Message::*; match self { Ping => { result.put_u8(0x00); } - Connect(channel, port) => { + Hello(major, minor, details) => { result.put_u8(0x01); + result.put_u8(*major); + result.put_u8(*minor); + result.put_u16(details.len().try_into().expect("Too many details")); + for detail in details { + put_string(result, detail); + } + } + Connect(channel, port) => { + result.put_u8(0x02); result.put_u64(*channel); result.put_u16(*port); } Connected(channel) => { - result.put_u8(0x02); - result.put_u64(*channel); - } - Close(channel) => { result.put_u8(0x03); result.put_u64(*channel); } - // Abort(channel) => { - // result.put_u8(0x04); - // result.put_u64(*channel); - // } - Closed(channel) => { - result.put_u8(0x05); + Close(channel) => { + result.put_u8(0x04); result.put_u64(*channel); } Refresh => { - result.put_u8(0x06); + result.put_u8(0x05); } Ports(ports) => { - result.put_u8(0x07); + result.put_u8(0x06); result.put_u16(ports.len().try_into().expect("Too many ports")); for port in ports { result.put_u16(port.port); + // Port descriptions can be long, let's make sure they're not. let sliced = slice_up_to(&port.desc, u16::max_value().into()); - result.put_u16(sliced.len().try_into().unwrap()); - result.put_slice(sliced.as_bytes()); + put_string(result, sliced); } } Data(channel, bytes) => { - result.put_u8(0x08); + result.put_u8(0x07); result.put_u64(*channel); result.put_u16(bytes.len().try_into().expect("Payload too big")); - result.put_slice(bytes); // I hate that this copies. We should make this an async write probably. + result.put_slice(bytes); // I hate that this copies. We should make this an async write probably, maybe? } }; - result } - pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result { + pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result { use Message::*; match get_u8(cursor)? { 0x00 => Ok(Ping), 0x01 => { + let major = get_u8(cursor)?; + let minor = get_u8(cursor)?; + let count = get_u16(cursor)?; + let mut details = Vec::with_capacity(count.into()); + for _ in 0..count { + details.push(get_string(cursor)?); + } + Ok(Hello(major, minor, details)) + } + 0x02 => { let channel = get_u64(cursor)?; let port = get_u16(cursor)?; Ok(Connect(channel, port)) } - 0x02 => { + 0x03 => { let channel = get_u64(cursor)?; Ok(Connected(channel)) } - 0x03 => { + 0x04 => { let channel = get_u64(cursor)?; Ok(Close(channel)) } - // 0x04 => { - // let channel = get_u64(cursor)?; - // Ok(Abort(channel)) - // } - 0x05 => { - let channel = get_u64(cursor)?; - Ok(Closed(channel)) - } - 0x06 => Ok(Refresh), - 0x07 => { + 0x05 => Ok(Refresh), + 0x06 => { let count = get_u16(cursor)?; - - let mut ports = Vec::new(); + let mut ports = Vec::with_capacity(count.into()); for _ in 0..count { let port = get_u16(cursor)?; - let length = get_u16(cursor)?; - - let data = get_bytes(cursor, length.into())?; - let desc = match std::str::from_utf8(&data[..]) { - Ok(s) => s.to_owned(), - Err(_) => return Err(MessageError::Corrupt), - }; - + let desc = get_string(cursor)?; ports.push(PortDesc { port, desc }); } Ok(Ports(ports)) } - 0x08 => { + 0x07 => { let channel = get_u64(cursor)?; let length = get_u16(cursor)?; let data = get_bytes(cursor, length.into())?; Ok(Data(channel, data)) } - _ => Err(MessageError::UnknownMessage), + _ => Err(Error::MessageUnknown), } } } @@ -152,12 +145,17 @@ mod message_tests { #[test] fn round_trip() { assert_round_trip(Ping); + assert_round_trip(Hello( + 0x12, + 0x00, + vec!["One".to_string(), "Two".to_string(), "Three".to_string()], + )); + assert_round_trip(Hello(0x00, 0x01, vec![])); assert_round_trip(Connect(0x1234567890123456, 0x1234)); assert_round_trip(Connected(0x1234567890123456)); assert_round_trip(Close(0x1234567890123456)); - // assert_round_trip(Abort(0x1234567890123456)); - assert_round_trip(Closed(0x1234567890123456)); assert_round_trip(Refresh); + assert_round_trip(Ports(vec![])); assert_round_trip(Ports(vec![ PortDesc { port: 8080, @@ -170,38 +168,64 @@ mod message_tests { ])); assert_round_trip(Data(0x1234567890123456, vec![1, 2, 3, 4].into())); } + + #[test] + fn big_port_desc() { + // Strings are capped at 64k let's make a big one! + let char = String::from_utf8(vec![0xe0, 0xa0, 0x83]).unwrap(); + let mut str = String::with_capacity(128 * 1024); + while str.len() < 128 * 1024 { + str.push_str(&char); + } + + let msg = Ports(vec![PortDesc { + port: 8080, + desc: str, + }]); + msg.encode(); + } } -fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { if !cursor.has_remaining() { - return Err(MessageError::Incomplete); + return Err(Error::MessageIncomplete); } Ok(cursor.get_u8()) } -fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { if cursor.remaining() < 2 { - return Err(MessageError::Incomplete); + return Err(Error::MessageIncomplete); } Ok(cursor.get_u16()) } -fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { if cursor.remaining() < 8 { - return Err(MessageError::Incomplete); + return Err(Error::MessageIncomplete); } Ok(cursor.get_u64()) } -fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { +fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { if cursor.remaining() < length { - return Err(MessageError::Incomplete); + return Err(Error::MessageIncomplete); } Ok(cursor.copy_to_bytes(length)) } -pub fn slice_up_to(s: &str, max_len: usize) -> &str { +fn get_string(cursor: &mut Cursor<&[u8]>) -> Result { + let length = get_u16(cursor)?; + + let data = get_bytes(cursor, length.into())?; + match std::str::from_utf8(&data[..]) { + Ok(s) => Ok(s.to_owned()), + Err(_) => return Err(Error::MessageCorrupt), + } +} + +fn slice_up_to(s: &str, max_len: usize) -> &str { if max_len >= s.len() { return s; } @@ -211,3 +235,62 @@ pub fn slice_up_to(s: &str, max_len: usize) -> &str { } &s[..idx] } + +fn put_string(target: &mut T, str: &str) { + target.put_u16(str.len().try_into().expect("String is too long")); + target.put_slice(str.as_bytes()); +} + +// ---------------------------------------------------------------------------- +// Message IO + +pub struct MessageWriter { + writer: T, +} + +impl MessageWriter { + pub fn new(writer: T) -> MessageWriter { + MessageWriter { writer } + } + pub async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { + match self.write_impl(msg).await { + Err(e) => Err(Error::IO(e)), + Ok(ok) => Ok(ok), + } + } + async fn write_impl(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> { + // TODO: Optimize buffer usage please this is bad + // eprintln!("? {:?}", msg); + let mut buffer = msg.encode(); + self.writer + .write_u32(buffer.len().try_into().expect("Message too large")) + .await?; + self.writer.write_buf(&mut buffer).await?; + self.writer.flush().await?; + Ok(()) + } +} + +pub struct MessageReader { + reader: T, +} + +impl MessageReader { + pub fn new(reader: T) -> MessageReader { + MessageReader { reader } + } + pub async fn read(self: &mut Self) -> Result { + let frame_length = match self.reader.read_u32().await { + Ok(l) => l, + Err(e) => return Err(Error::IO(e)), + }; + + let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); + if let Err(e) = self.reader.read_buf(&mut data).await { + return Err(Error::IO(e)); + } + + let mut cursor = Cursor::new(&data[..]); + Message::decode(&mut cursor) + } +}