Protocol version, async pump, start some testing
This commit is contained in:
parent
763ecd190e
commit
6f906d80a7
6 changed files with 237 additions and 38 deletions
9
Cargo.lock
generated
9
Cargo.lock
generated
|
|
@ -23,6 +23,12 @@ version = "1.0.65"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
||||
|
||||
[[package]]
|
||||
name = "assert_matches"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
|
|
@ -166,9 +172,10 @@ checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
|
|||
|
||||
[[package]]
|
||||
name = "fwd"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_matches",
|
||||
"bytes",
|
||||
"crossterm",
|
||||
"home",
|
||||
|
|
|
|||
|
|
@ -21,5 +21,8 @@ tokio-stream = "0.1"
|
|||
toml = "0.5"
|
||||
tui = "0.19"
|
||||
|
||||
[dev-dependencies]
|
||||
assert_matches = "1"
|
||||
|
||||
[target.'cfg(target_os="linux")'.dependencies]
|
||||
procfs = "0.14.1"
|
||||
|
|
|
|||
|
|
@ -232,7 +232,8 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
|||
) -> Result<()> {
|
||||
// Wait for the server's announcement.
|
||||
if let Message::Hello(major, minor, _) = reader.read().await? {
|
||||
if major != 0 || minor > 1 {
|
||||
info!("Server Version: {major} {minor}");
|
||||
if major != 0 || minor > 2 {
|
||||
bail!("Unsupported remote protocol version {}.{}", major, minor);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -253,11 +254,13 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
|||
}
|
||||
} => {
|
||||
if let Err(e) = result {
|
||||
print!("Error sending refreshes\n");
|
||||
return Err(e.into());
|
||||
}
|
||||
},
|
||||
result = client_handle_messages(reader, events) => {
|
||||
if let Err(e) = result {
|
||||
print!("Error handling messages\n");
|
||||
return Err(e.into());
|
||||
}
|
||||
},
|
||||
|
|
@ -391,3 +394,91 @@ pub async fn run_client(remote: &str) {
|
|||
_ = client_connect_loop(remote, event_sender) => ()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
|
||||
struct Fixture {
|
||||
_server_read: MessageReader<DuplexStream>,
|
||||
server_write: MessageWriter<DuplexStream>,
|
||||
_event_receiver: Receiver<ui::UIEvent>,
|
||||
client_result: Option<tokio::task::JoinHandle<anyhow::Result<()>>>,
|
||||
}
|
||||
|
||||
impl Fixture {
|
||||
pub fn new() -> Self {
|
||||
let (server_read, client_write) = tokio::io::duplex(4096);
|
||||
let server_read = MessageReader::new(server_read);
|
||||
let client_write = MessageWriter::new(client_write);
|
||||
|
||||
let (client_read, server_write) = tokio::io::duplex(4096);
|
||||
let client_read = MessageReader::new(client_read);
|
||||
let server_write = MessageWriter::new(server_write);
|
||||
|
||||
let (event_sender, event_receiver) = mpsc::channel(1024);
|
||||
|
||||
let client_result = tokio::spawn(async move {
|
||||
let mut client_read = client_read;
|
||||
let mut client_write = client_write;
|
||||
client_main(
|
||||
0,
|
||||
&mut client_read,
|
||||
&mut client_write,
|
||||
event_sender,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
Fixture {
|
||||
_server_read: server_read,
|
||||
server_write,
|
||||
_event_receiver: event_receiver,
|
||||
client_result: Some(client_result),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn shutdown(mut self) -> anyhow::Result<()> {
|
||||
let result = self.client_result.take();
|
||||
drop(self); // Side effect: close all streams.
|
||||
result.unwrap().await.expect("Unexpected join error")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic_hello_sync() {
|
||||
let mut t = Fixture::new();
|
||||
|
||||
t.server_write
|
||||
.write(Message::Hello(0, 2, vec![]))
|
||||
.await
|
||||
.expect("Error sending hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic_hello_high_minor() {
|
||||
let mut t = Fixture::new();
|
||||
|
||||
t.server_write
|
||||
.write(Message::Hello(0, 99, vec![]))
|
||||
.await
|
||||
.expect("Error sending hello");
|
||||
|
||||
assert_matches!(t.shutdown().await, Err(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic_hello_wrong_major() {
|
||||
let mut t = Fixture::new();
|
||||
|
||||
t.server_write
|
||||
.write(Message::Hello(99, 0, vec![]))
|
||||
.await
|
||||
.expect("Error sending hello");
|
||||
|
||||
assert_matches!(t.shutdown().await, Err(_));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,40 @@
|
|||
use anyhow::Result;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use std::io::Cursor;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Messages
|
||||
// Errors
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MessageError {
|
||||
pub enum Error {
|
||||
#[error("Message type unknown: {0}")]
|
||||
Unknown(u8),
|
||||
#[error("Message incomplete")]
|
||||
Incomplete,
|
||||
#[error("String contained invalid UTF8: {0}")]
|
||||
InvalidString(std::str::Utf8Error),
|
||||
#[error("IO Error occurred: {0}")]
|
||||
IO(std::io::Error),
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for Error {
|
||||
fn from(value: std::str::Utf8Error) -> Self {
|
||||
Self::InvalidString(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::IO(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Messages
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct PortDesc {
|
||||
pub port: u16,
|
||||
|
|
@ -23,10 +43,17 @@ pub struct PortDesc {
|
|||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum Message {
|
||||
Ping, // Ignored on both sides, can be used to test connection.
|
||||
Hello(u8, u8, Vec<String>), // Server info announcement: major version, minor version, headers.
|
||||
Refresh, // Request to refresh list of ports from client.
|
||||
Ports(Vec<PortDesc>), // List of available ports from server to client.
|
||||
// Ignored on both sides, can be used to test connection.
|
||||
Ping,
|
||||
|
||||
// Server info announcement: major version, minor version, headers.
|
||||
Hello(u8, u8, Vec<String>),
|
||||
|
||||
// Request to refresh list of ports from client.
|
||||
Refresh,
|
||||
|
||||
// List of available ports from server to client.
|
||||
Ports(Vec<PortDesc>),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
|
|
@ -46,7 +73,9 @@ impl Message {
|
|||
result.put_u8(0x01);
|
||||
result.put_u8(*major);
|
||||
result.put_u8(*minor);
|
||||
result.put_u16(details.len().try_into().expect("Too many details"));
|
||||
result.put_u16(
|
||||
details.len().try_into().expect("Too many details"),
|
||||
);
|
||||
for detail in details {
|
||||
put_string(result, detail);
|
||||
}
|
||||
|
|
@ -62,7 +91,8 @@ impl Message {
|
|||
result.put_u16(port.port);
|
||||
|
||||
// Port descriptions can be long, let's make sure they're not.
|
||||
let sliced = slice_up_to(&port.desc, u16::max_value().into());
|
||||
let sliced =
|
||||
slice_up_to(&port.desc, u16::max_value().into());
|
||||
put_string(result, sliced);
|
||||
}
|
||||
}
|
||||
|
|
@ -94,28 +124,28 @@ impl Message {
|
|||
}
|
||||
Ok(Ports(ports))
|
||||
}
|
||||
b => Err(MessageError::Unknown(b).into()),
|
||||
b => Err(Error::Unknown(b).into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> {
|
||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8> {
|
||||
if !cursor.has_remaining() {
|
||||
return Err(MessageError::Incomplete);
|
||||
return Err(Error::Incomplete);
|
||||
}
|
||||
Ok(cursor.get_u8())
|
||||
}
|
||||
|
||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> {
|
||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16> {
|
||||
if cursor.remaining() < 2 {
|
||||
return Err(MessageError::Incomplete);
|
||||
return Err(Error::Incomplete);
|
||||
}
|
||||
Ok(cursor.get_u16())
|
||||
}
|
||||
|
||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> {
|
||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes> {
|
||||
if cursor.remaining() < length {
|
||||
return Err(MessageError::Incomplete);
|
||||
return Err(Error::Incomplete);
|
||||
}
|
||||
|
||||
Ok(cursor.copy_to_bytes(length))
|
||||
|
|
@ -255,10 +285,7 @@ mod message_tests {
|
|||
str.push_str(&char);
|
||||
}
|
||||
|
||||
let msg = Ports(vec![PortDesc {
|
||||
port: 8080,
|
||||
desc: str,
|
||||
}]);
|
||||
let msg = Ports(vec![PortDesc { port: 8080, desc: str }]);
|
||||
msg.encode();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,15 +2,34 @@ use crate::message::{Message, MessageReader, MessageWriter};
|
|||
use anyhow::Result;
|
||||
use log::{error, warn};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
mod refresh;
|
||||
|
||||
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||
reader: &mut MessageReader<Reader>,
|
||||
// We drive writes through an mpsc queue, because we not only handle requests
|
||||
// and responses from the client (refresh ports and the like) but also need
|
||||
// to asynchronously send messages to the client (open this URL, etc).
|
||||
async fn write_driver<Writer: AsyncWrite + Unpin>(
|
||||
messages: &mut mpsc::Receiver<Message>,
|
||||
writer: &mut MessageWriter<Writer>,
|
||||
) -> () {
|
||||
loop {
|
||||
match messages.recv().await {
|
||||
Some(m) => {
|
||||
writer.write(m).await.expect("Failed to write the message")
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle messages that the client sends to us.
|
||||
async fn server_loop<Reader: AsyncRead + Unpin>(
|
||||
reader: &mut MessageReader<Reader>,
|
||||
writer: &mut mpsc::Sender<Message>,
|
||||
) -> Result<()> {
|
||||
// The first message we send must be an announcement.
|
||||
writer.write(Message::Hello(0, 1, vec![])).await?;
|
||||
writer.send(Message::Hello(0, 2, vec![])).await?;
|
||||
|
||||
loop {
|
||||
use Message::*;
|
||||
|
|
@ -24,7 +43,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
|||
vec![]
|
||||
}
|
||||
};
|
||||
if let Err(e) = writer.write(Message::Ports(ports)).await {
|
||||
if let Err(e) = writer.send(Message::Ports(ports)).await {
|
||||
// Writer has been closed for some reason, we can just
|
||||
// quit.... I hope everything is OK?
|
||||
warn!("Warning: Error sending: {:?}", e);
|
||||
|
|
@ -35,24 +54,76 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn run_server() {
|
||||
let reader = BufReader::new(tokio::io::stdin());
|
||||
let mut writer = BufWriter::new(tokio::io::stdout());
|
||||
// Run the various server loops.
|
||||
async fn server_main<
|
||||
In: AsyncRead + Unpin + Send,
|
||||
Out: AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
stdin: In,
|
||||
stdout: Out,
|
||||
) -> Result<()> {
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut writer = BufWriter::new(stdout);
|
||||
|
||||
// Write the 8-byte synchronization marker.
|
||||
writer
|
||||
.write_u64(0x00_00_00_00_00_00_00_00)
|
||||
.await
|
||||
.expect("Error writing marker");
|
||||
writer.flush().await.expect("Error flushing");
|
||||
|
||||
if let Err(e) = writer.flush().await {
|
||||
eprintln!("Error writing sync marker: {:?}", e);
|
||||
return;
|
||||
}
|
||||
|
||||
let (mut sender, mut receiver) = mpsc::channel(10);
|
||||
let mut writer = MessageWriter::new(writer);
|
||||
let mut reader = MessageReader::new(reader);
|
||||
if let Err(e) = server_main(&mut reader, &mut writer).await {
|
||||
eprintln!("Error: {:?}", e);
|
||||
|
||||
let (_, result) = tokio::join!(
|
||||
write_driver(&mut receiver, &mut writer),
|
||||
server_loop(&mut reader, &mut sender)
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn run_server() {
|
||||
let stdin = tokio::io::stdin();
|
||||
let stdout = tokio::io::stdout();
|
||||
if let Err(e) = server_main(stdin, stdout).await {
|
||||
error!("Error: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_matches::assert_matches;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
|
||||
async fn sync(client_read: &mut DuplexStream) {
|
||||
print!("[client] Waiting for server sync...\n");
|
||||
for _ in 0..8 {
|
||||
let b = client_read
|
||||
.read_u8()
|
||||
.await
|
||||
.expect("Error reading sync byte");
|
||||
assert_eq!(b, 0);
|
||||
}
|
||||
|
||||
let mut reader = MessageReader::new(client_read);
|
||||
print!("[client] Reading first message...\n");
|
||||
let msg = reader.read().await.expect("Error reading first message");
|
||||
assert_matches!(msg, Message::Hello(0, 2, _));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic_hello_sync() {
|
||||
let (server_read, _client_write) = tokio::io::duplex(4096);
|
||||
let (mut client_read, server_write) = tokio::io::duplex(4096);
|
||||
|
||||
tokio::spawn(async move {
|
||||
server_main(server_read, server_write)
|
||||
.await
|
||||
.expect("Error in server!");
|
||||
});
|
||||
|
||||
sync(&mut client_read).await;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
4
test.py
4
test.py
|
|
@ -1,5 +1,4 @@
|
|||
#!/bin/env python3
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
|
@ -14,6 +13,7 @@ def ssh(remote, *args):
|
|||
|
||||
def main(args):
|
||||
local("cargo", "build")
|
||||
local("cargo", "build", "--target=x86_64-unknown-linux-gnu")
|
||||
|
||||
remote = args[1]
|
||||
print(f"Copying file to {remote}...")
|
||||
|
|
@ -23,7 +23,7 @@ def main(args):
|
|||
capture_output=True,
|
||||
)
|
||||
|
||||
print(f"Starting process...")
|
||||
print("Starting process...")
|
||||
subprocess.run(["target/debug/fwd", remote])
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue