From 6f906d80a7ea52947abd95c956ba22f17b74b441 Mon Sep 17 00:00:00 2001 From: John Doty Date: Fri, 16 Dec 2022 13:57:52 -0800 Subject: [PATCH] Protocol version, async pump, start some testing --- Cargo.lock | 9 ++++- Cargo.toml | 3 ++ src/client/mod.rs | 93 +++++++++++++++++++++++++++++++++++++++++++- src/message.rs | 67 ++++++++++++++++++++++---------- src/server/mod.rs | 99 ++++++++++++++++++++++++++++++++++++++++------- test.py | 4 +- 6 files changed, 237 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 98bbbec..d71888e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "autocfg" version = "1.1.0" @@ -166,9 +172,10 @@ checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" [[package]] name = "fwd" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", + "assert_matches", "bytes", "crossterm", "home", diff --git a/Cargo.toml b/Cargo.toml index b60433b..42466f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,8 @@ tokio-stream = "0.1" toml = "0.5" tui = "0.19" +[dev-dependencies] +assert_matches = "1" + [target.'cfg(target_os="linux")'.dependencies] procfs = "0.14.1" diff --git a/src/client/mod.rs b/src/client/mod.rs index 908c9c3..7b64b1c 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -232,7 +232,8 @@ async fn client_main( ) -> Result<()> { // Wait for the server's announcement. if let Message::Hello(major, minor, _) = reader.read().await? { - if major != 0 || minor > 1 { + info!("Server Version: {major} {minor}"); + if major != 0 || minor > 2 { bail!("Unsupported remote protocol version {}.{}", major, minor); } } else { @@ -253,11 +254,13 @@ async fn client_main( } } => { if let Err(e) = result { + print!("Error sending refreshes\n"); return Err(e.into()); } }, result = client_handle_messages(reader, events) => { if let Err(e) = result { + print!("Error handling messages\n"); return Err(e.into()); } }, @@ -391,3 +394,91 @@ pub async fn run_client(remote: &str) { _ = client_connect_loop(remote, event_sender) => () } } + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use tokio::io::DuplexStream; + use tokio::sync::mpsc::Receiver; + + struct Fixture { + _server_read: MessageReader, + server_write: MessageWriter, + _event_receiver: Receiver, + client_result: Option>>, + } + + impl Fixture { + pub fn new() -> Self { + let (server_read, client_write) = tokio::io::duplex(4096); + let server_read = MessageReader::new(server_read); + let client_write = MessageWriter::new(client_write); + + let (client_read, server_write) = tokio::io::duplex(4096); + let client_read = MessageReader::new(client_read); + let server_write = MessageWriter::new(server_write); + + let (event_sender, event_receiver) = mpsc::channel(1024); + + let client_result = tokio::spawn(async move { + let mut client_read = client_read; + let mut client_write = client_write; + client_main( + 0, + &mut client_read, + &mut client_write, + event_sender, + ) + .await + }); + + Fixture { + _server_read: server_read, + server_write, + _event_receiver: event_receiver, + client_result: Some(client_result), + } + } + + pub async fn shutdown(mut self) -> anyhow::Result<()> { + let result = self.client_result.take(); + drop(self); // Side effect: close all streams. + result.unwrap().await.expect("Unexpected join error") + } + } + + #[tokio::test] + async fn basic_hello_sync() { + let mut t = Fixture::new(); + + t.server_write + .write(Message::Hello(0, 2, vec![])) + .await + .expect("Error sending hello"); + } + + #[tokio::test] + async fn basic_hello_high_minor() { + let mut t = Fixture::new(); + + t.server_write + .write(Message::Hello(0, 99, vec![])) + .await + .expect("Error sending hello"); + + assert_matches!(t.shutdown().await, Err(_)); + } + + #[tokio::test] + async fn basic_hello_wrong_major() { + let mut t = Fixture::new(); + + t.server_write + .write(Message::Hello(99, 0, vec![])) + .await + .expect("Error sending hello"); + + assert_matches!(t.shutdown().await, Err(_)); + } +} diff --git a/src/message.rs b/src/message.rs index b638866..66c81c4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,20 +1,40 @@ -use anyhow::Result; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::io::Cursor; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // ---------------------------------------------------------------------------- -// Messages +// Errors #[derive(Debug, Error)] -pub enum MessageError { +pub enum Error { #[error("Message type unknown: {0}")] Unknown(u8), #[error("Message incomplete")] Incomplete, + #[error("String contained invalid UTF8: {0}")] + InvalidString(std::str::Utf8Error), + #[error("IO Error occurred: {0}")] + IO(std::io::Error), } +impl From for Error { + fn from(value: std::str::Utf8Error) -> Self { + Self::InvalidString(value) + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +pub type Result = std::result::Result; + +// ---------------------------------------------------------------------------- +// Messages + #[derive(Debug, PartialEq, Clone)] pub struct PortDesc { pub port: u16, @@ -23,10 +43,17 @@ pub struct PortDesc { #[derive(Debug, PartialEq, Clone)] pub enum Message { - Ping, // Ignored on both sides, can be used to test connection. - Hello(u8, u8, Vec), // Server info announcement: major version, minor version, headers. - Refresh, // Request to refresh list of ports from client. - Ports(Vec), // List of available ports from server to client. + // Ignored on both sides, can be used to test connection. + Ping, + + // Server info announcement: major version, minor version, headers. + Hello(u8, u8, Vec), + + // Request to refresh list of ports from client. + Refresh, + + // List of available ports from server to client. + Ports(Vec), } impl Message { @@ -46,7 +73,9 @@ impl Message { result.put_u8(0x01); result.put_u8(*major); result.put_u8(*minor); - result.put_u16(details.len().try_into().expect("Too many details")); + result.put_u16( + details.len().try_into().expect("Too many details"), + ); for detail in details { put_string(result, detail); } @@ -62,7 +91,8 @@ impl Message { result.put_u16(port.port); // Port descriptions can be long, let's make sure they're not. - let sliced = slice_up_to(&port.desc, u16::max_value().into()); + let sliced = + slice_up_to(&port.desc, u16::max_value().into()); put_string(result, sliced); } } @@ -94,28 +124,28 @@ impl Message { } Ok(Ports(ports)) } - b => Err(MessageError::Unknown(b).into()), + b => Err(Error::Unknown(b).into()), } } } -fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { if !cursor.has_remaining() { - return Err(MessageError::Incomplete); + return Err(Error::Incomplete); } Ok(cursor.get_u8()) } -fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { +fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { if cursor.remaining() < 2 { - return Err(MessageError::Incomplete); + return Err(Error::Incomplete); } Ok(cursor.get_u16()) } -fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { +fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { if cursor.remaining() < length { - return Err(MessageError::Incomplete); + return Err(Error::Incomplete); } Ok(cursor.copy_to_bytes(length)) @@ -255,10 +285,7 @@ mod message_tests { str.push_str(&char); } - let msg = Ports(vec![PortDesc { - port: 8080, - desc: str, - }]); + let msg = Ports(vec![PortDesc { port: 8080, desc: str }]); msg.encode(); } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 0d0f7be..0309ae7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,15 +2,34 @@ use crate::message::{Message, MessageReader, MessageWriter}; use anyhow::Result; use log::{error, warn}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; +use tokio::sync::mpsc; mod refresh; -async fn server_main( - reader: &mut MessageReader, +// We drive writes through an mpsc queue, because we not only handle requests +// and responses from the client (refresh ports and the like) but also need +// to asynchronously send messages to the client (open this URL, etc). +async fn write_driver( + messages: &mut mpsc::Receiver, writer: &mut MessageWriter, +) -> () { + loop { + match messages.recv().await { + Some(m) => { + writer.write(m).await.expect("Failed to write the message") + } + None => break, + } + } +} + +// Handle messages that the client sends to us. +async fn server_loop( + reader: &mut MessageReader, + writer: &mut mpsc::Sender, ) -> Result<()> { // The first message we send must be an announcement. - writer.write(Message::Hello(0, 1, vec![])).await?; + writer.send(Message::Hello(0, 2, vec![])).await?; loop { use Message::*; @@ -24,7 +43,7 @@ async fn server_main( vec![] } }; - if let Err(e) = writer.write(Message::Ports(ports)).await { + if let Err(e) = writer.send(Message::Ports(ports)).await { // Writer has been closed for some reason, we can just // quit.... I hope everything is OK? warn!("Warning: Error sending: {:?}", e); @@ -35,24 +54,76 @@ async fn server_main( } } -pub async fn run_server() { - let reader = BufReader::new(tokio::io::stdin()); - let mut writer = BufWriter::new(tokio::io::stdout()); +// Run the various server loops. +async fn server_main< + In: AsyncRead + Unpin + Send, + Out: AsyncWrite + Unpin + Send, +>( + stdin: In, + stdout: Out, +) -> Result<()> { + let reader = BufReader::new(stdin); + let mut writer = BufWriter::new(stdout); // Write the 8-byte synchronization marker. writer .write_u64(0x00_00_00_00_00_00_00_00) .await .expect("Error writing marker"); + writer.flush().await.expect("Error flushing"); - if let Err(e) = writer.flush().await { - eprintln!("Error writing sync marker: {:?}", e); - return; - } - + let (mut sender, mut receiver) = mpsc::channel(10); let mut writer = MessageWriter::new(writer); let mut reader = MessageReader::new(reader); - if let Err(e) = server_main(&mut reader, &mut writer).await { - eprintln!("Error: {:?}", e); + + let (_, result) = tokio::join!( + write_driver(&mut receiver, &mut writer), + server_loop(&mut reader, &mut sender) + ); + result +} + +pub async fn run_server() { + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + if let Err(e) = server_main(stdin, stdout).await { + error!("Error: {:?}", e); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use tokio::io::{AsyncReadExt, DuplexStream}; + + async fn sync(client_read: &mut DuplexStream) { + print!("[client] Waiting for server sync...\n"); + for _ in 0..8 { + let b = client_read + .read_u8() + .await + .expect("Error reading sync byte"); + assert_eq!(b, 0); + } + + let mut reader = MessageReader::new(client_read); + print!("[client] Reading first message...\n"); + let msg = reader.read().await.expect("Error reading first message"); + assert_matches!(msg, Message::Hello(0, 2, _)); + } + + #[tokio::test] + async fn basic_hello_sync() { + let (server_read, _client_write) = tokio::io::duplex(4096); + let (mut client_read, server_write) = tokio::io::duplex(4096); + + tokio::spawn(async move { + server_main(server_read, server_write) + .await + .expect("Error in server!"); + }); + + sync(&mut client_read).await; } } diff --git a/test.py b/test.py index f4fc40a..435371c 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,4 @@ #!/bin/env python3 -import os import subprocess import sys @@ -14,6 +13,7 @@ def ssh(remote, *args): def main(args): local("cargo", "build") + local("cargo", "build", "--target=x86_64-unknown-linux-gnu") remote = args[1] print(f"Copying file to {remote}...") @@ -23,7 +23,7 @@ def main(args): capture_output=True, ) - print(f"Starting process...") + print("Starting process...") subprocess.run(["target/debug/fwd", remote])