Protocol version, async pump, start some testing

This commit is contained in:
John Doty 2022-12-16 13:57:52 -08:00
parent 763ecd190e
commit 6f906d80a7
6 changed files with 237 additions and 38 deletions

9
Cargo.lock generated
View file

@ -23,6 +23,12 @@ version = "1.0.65"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
[[package]]
name = "assert_matches"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.1.0" version = "1.1.0"
@ -166,9 +172,10 @@ checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
[[package]] [[package]]
name = "fwd" name = "fwd"
version = "0.1.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assert_matches",
"bytes", "bytes",
"crossterm", "crossterm",
"home", "home",

View file

@ -21,5 +21,8 @@ tokio-stream = "0.1"
toml = "0.5" toml = "0.5"
tui = "0.19" tui = "0.19"
[dev-dependencies]
assert_matches = "1"
[target.'cfg(target_os="linux")'.dependencies] [target.'cfg(target_os="linux")'.dependencies]
procfs = "0.14.1" procfs = "0.14.1"

View file

@ -232,7 +232,8 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
) -> Result<()> { ) -> Result<()> {
// Wait for the server's announcement. // Wait for the server's announcement.
if let Message::Hello(major, minor, _) = reader.read().await? { if let Message::Hello(major, minor, _) = reader.read().await? {
if major != 0 || minor > 1 { info!("Server Version: {major} {minor}");
if major != 0 || minor > 2 {
bail!("Unsupported remote protocol version {}.{}", major, minor); bail!("Unsupported remote protocol version {}.{}", major, minor);
} }
} else { } else {
@ -253,11 +254,13 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
} }
} => { } => {
if let Err(e) = result { if let Err(e) = result {
print!("Error sending refreshes\n");
return Err(e.into()); return Err(e.into());
} }
}, },
result = client_handle_messages(reader, events) => { result = client_handle_messages(reader, events) => {
if let Err(e) = result { if let Err(e) = result {
print!("Error handling messages\n");
return Err(e.into()); return Err(e.into());
} }
}, },
@ -391,3 +394,91 @@ pub async fn run_client(remote: &str) {
_ = client_connect_loop(remote, event_sender) => () _ = client_connect_loop(remote, event_sender) => ()
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use tokio::io::DuplexStream;
use tokio::sync::mpsc::Receiver;
struct Fixture {
_server_read: MessageReader<DuplexStream>,
server_write: MessageWriter<DuplexStream>,
_event_receiver: Receiver<ui::UIEvent>,
client_result: Option<tokio::task::JoinHandle<anyhow::Result<()>>>,
}
impl Fixture {
pub fn new() -> Self {
let (server_read, client_write) = tokio::io::duplex(4096);
let server_read = MessageReader::new(server_read);
let client_write = MessageWriter::new(client_write);
let (client_read, server_write) = tokio::io::duplex(4096);
let client_read = MessageReader::new(client_read);
let server_write = MessageWriter::new(server_write);
let (event_sender, event_receiver) = mpsc::channel(1024);
let client_result = tokio::spawn(async move {
let mut client_read = client_read;
let mut client_write = client_write;
client_main(
0,
&mut client_read,
&mut client_write,
event_sender,
)
.await
});
Fixture {
_server_read: server_read,
server_write,
_event_receiver: event_receiver,
client_result: Some(client_result),
}
}
pub async fn shutdown(mut self) -> anyhow::Result<()> {
let result = self.client_result.take();
drop(self); // Side effect: close all streams.
result.unwrap().await.expect("Unexpected join error")
}
}
#[tokio::test]
async fn basic_hello_sync() {
let mut t = Fixture::new();
t.server_write
.write(Message::Hello(0, 2, vec![]))
.await
.expect("Error sending hello");
}
#[tokio::test]
async fn basic_hello_high_minor() {
let mut t = Fixture::new();
t.server_write
.write(Message::Hello(0, 99, vec![]))
.await
.expect("Error sending hello");
assert_matches!(t.shutdown().await, Err(_));
}
#[tokio::test]
async fn basic_hello_wrong_major() {
let mut t = Fixture::new();
t.server_write
.write(Message::Hello(99, 0, vec![]))
.await
.expect("Error sending hello");
assert_matches!(t.shutdown().await, Err(_));
}
}

View file

@ -1,20 +1,40 @@
use anyhow::Result;
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io::Cursor; use std::io::Cursor;
use thiserror::Error; use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Messages // Errors
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum MessageError { pub enum Error {
#[error("Message type unknown: {0}")] #[error("Message type unknown: {0}")]
Unknown(u8), Unknown(u8),
#[error("Message incomplete")] #[error("Message incomplete")]
Incomplete, Incomplete,
#[error("String contained invalid UTF8: {0}")]
InvalidString(std::str::Utf8Error),
#[error("IO Error occurred: {0}")]
IO(std::io::Error),
} }
impl From<std::str::Utf8Error> for Error {
fn from(value: std::str::Utf8Error) -> Self {
Self::InvalidString(value)
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Self::IO(value)
}
}
pub type Result<T> = std::result::Result<T, Error>;
// ----------------------------------------------------------------------------
// Messages
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct PortDesc { pub struct PortDesc {
pub port: u16, pub port: u16,
@ -23,10 +43,17 @@ pub struct PortDesc {
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub enum Message { pub enum Message {
Ping, // Ignored on both sides, can be used to test connection. // Ignored on both sides, can be used to test connection.
Hello(u8, u8, Vec<String>), // Server info announcement: major version, minor version, headers. Ping,
Refresh, // Request to refresh list of ports from client.
Ports(Vec<PortDesc>), // List of available ports from server to client. // Server info announcement: major version, minor version, headers.
Hello(u8, u8, Vec<String>),
// Request to refresh list of ports from client.
Refresh,
// List of available ports from server to client.
Ports(Vec<PortDesc>),
} }
impl Message { impl Message {
@ -46,7 +73,9 @@ impl Message {
result.put_u8(0x01); result.put_u8(0x01);
result.put_u8(*major); result.put_u8(*major);
result.put_u8(*minor); result.put_u8(*minor);
result.put_u16(details.len().try_into().expect("Too many details")); result.put_u16(
details.len().try_into().expect("Too many details"),
);
for detail in details { for detail in details {
put_string(result, detail); put_string(result, detail);
} }
@ -62,7 +91,8 @@ impl Message {
result.put_u16(port.port); result.put_u16(port.port);
// Port descriptions can be long, let's make sure they're not. // Port descriptions can be long, let's make sure they're not.
let sliced = slice_up_to(&port.desc, u16::max_value().into()); let sliced =
slice_up_to(&port.desc, u16::max_value().into());
put_string(result, sliced); put_string(result, sliced);
} }
} }
@ -94,28 +124,28 @@ impl Message {
} }
Ok(Ports(ports)) Ok(Ports(ports))
} }
b => Err(MessageError::Unknown(b).into()), b => Err(Error::Unknown(b).into()),
} }
} }
} }
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> { fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8> {
if !cursor.has_remaining() { if !cursor.has_remaining() {
return Err(MessageError::Incomplete); return Err(Error::Incomplete);
} }
Ok(cursor.get_u8()) Ok(cursor.get_u8())
} }
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> { fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16> {
if cursor.remaining() < 2 { if cursor.remaining() < 2 {
return Err(MessageError::Incomplete); return Err(Error::Incomplete);
} }
Ok(cursor.get_u16()) Ok(cursor.get_u16())
} }
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> { fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes> {
if cursor.remaining() < length { if cursor.remaining() < length {
return Err(MessageError::Incomplete); return Err(Error::Incomplete);
} }
Ok(cursor.copy_to_bytes(length)) Ok(cursor.copy_to_bytes(length))
@ -255,10 +285,7 @@ mod message_tests {
str.push_str(&char); str.push_str(&char);
} }
let msg = Ports(vec![PortDesc { let msg = Ports(vec![PortDesc { port: 8080, desc: str }]);
port: 8080,
desc: str,
}]);
msg.encode(); msg.encode();
} }
} }

View file

@ -2,15 +2,34 @@ use crate::message::{Message, MessageReader, MessageWriter};
use anyhow::Result; use anyhow::Result;
use log::{error, warn}; use log::{error, warn};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::mpsc;
mod refresh; mod refresh;
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>( // We drive writes through an mpsc queue, because we not only handle requests
reader: &mut MessageReader<Reader>, // and responses from the client (refresh ports and the like) but also need
// to asynchronously send messages to the client (open this URL, etc).
async fn write_driver<Writer: AsyncWrite + Unpin>(
messages: &mut mpsc::Receiver<Message>,
writer: &mut MessageWriter<Writer>, writer: &mut MessageWriter<Writer>,
) -> () {
loop {
match messages.recv().await {
Some(m) => {
writer.write(m).await.expect("Failed to write the message")
}
None => break,
}
}
}
// Handle messages that the client sends to us.
async fn server_loop<Reader: AsyncRead + Unpin>(
reader: &mut MessageReader<Reader>,
writer: &mut mpsc::Sender<Message>,
) -> Result<()> { ) -> Result<()> {
// The first message we send must be an announcement. // The first message we send must be an announcement.
writer.write(Message::Hello(0, 1, vec![])).await?; writer.send(Message::Hello(0, 2, vec![])).await?;
loop { loop {
use Message::*; use Message::*;
@ -24,7 +43,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
vec![] vec![]
} }
}; };
if let Err(e) = writer.write(Message::Ports(ports)).await { if let Err(e) = writer.send(Message::Ports(ports)).await {
// Writer has been closed for some reason, we can just // Writer has been closed for some reason, we can just
// quit.... I hope everything is OK? // quit.... I hope everything is OK?
warn!("Warning: Error sending: {:?}", e); warn!("Warning: Error sending: {:?}", e);
@ -35,24 +54,76 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
} }
} }
pub async fn run_server() { // Run the various server loops.
let reader = BufReader::new(tokio::io::stdin()); async fn server_main<
let mut writer = BufWriter::new(tokio::io::stdout()); In: AsyncRead + Unpin + Send,
Out: AsyncWrite + Unpin + Send,
>(
stdin: In,
stdout: Out,
) -> Result<()> {
let reader = BufReader::new(stdin);
let mut writer = BufWriter::new(stdout);
// Write the 8-byte synchronization marker. // Write the 8-byte synchronization marker.
writer writer
.write_u64(0x00_00_00_00_00_00_00_00) .write_u64(0x00_00_00_00_00_00_00_00)
.await .await
.expect("Error writing marker"); .expect("Error writing marker");
writer.flush().await.expect("Error flushing");
if let Err(e) = writer.flush().await { let (mut sender, mut receiver) = mpsc::channel(10);
eprintln!("Error writing sync marker: {:?}", e);
return;
}
let mut writer = MessageWriter::new(writer); let mut writer = MessageWriter::new(writer);
let mut reader = MessageReader::new(reader); let mut reader = MessageReader::new(reader);
if let Err(e) = server_main(&mut reader, &mut writer).await {
eprintln!("Error: {:?}", e); let (_, result) = tokio::join!(
write_driver(&mut receiver, &mut writer),
server_loop(&mut reader, &mut sender)
);
result
}
pub async fn run_server() {
let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout();
if let Err(e) = server_main(stdin, stdout).await {
error!("Error: {:?}", e);
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use tokio::io::{AsyncReadExt, DuplexStream};
async fn sync(client_read: &mut DuplexStream) {
print!("[client] Waiting for server sync...\n");
for _ in 0..8 {
let b = client_read
.read_u8()
.await
.expect("Error reading sync byte");
assert_eq!(b, 0);
}
let mut reader = MessageReader::new(client_read);
print!("[client] Reading first message...\n");
let msg = reader.read().await.expect("Error reading first message");
assert_matches!(msg, Message::Hello(0, 2, _));
}
#[tokio::test]
async fn basic_hello_sync() {
let (server_read, _client_write) = tokio::io::duplex(4096);
let (mut client_read, server_write) = tokio::io::duplex(4096);
tokio::spawn(async move {
server_main(server_read, server_write)
.await
.expect("Error in server!");
});
sync(&mut client_read).await;
} }
} }

View file

@ -1,5 +1,4 @@
#!/bin/env python3 #!/bin/env python3
import os
import subprocess import subprocess
import sys import sys
@ -14,6 +13,7 @@ def ssh(remote, *args):
def main(args): def main(args):
local("cargo", "build") local("cargo", "build")
local("cargo", "build", "--target=x86_64-unknown-linux-gnu")
remote = args[1] remote = args[1]
print(f"Copying file to {remote}...") print(f"Copying file to {remote}...")
@ -23,7 +23,7 @@ def main(args):
capture_output=True, capture_output=True,
) )
print(f"Starting process...") print("Starting process...")
subprocess.run(["target/debug/fwd", remote]) subprocess.run(["target/debug/fwd", remote])