Protocol version, async pump, start some testing
This commit is contained in:
parent
763ecd190e
commit
6f906d80a7
6 changed files with 237 additions and 38 deletions
9
Cargo.lock
generated
9
Cargo.lock
generated
|
|
@ -23,6 +23,12 @@ version = "1.0.65"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "assert_matches"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
|
@ -166,9 +172,10 @@ checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fwd"
|
name = "fwd"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"assert_matches",
|
||||||
"bytes",
|
"bytes",
|
||||||
"crossterm",
|
"crossterm",
|
||||||
"home",
|
"home",
|
||||||
|
|
|
||||||
|
|
@ -21,5 +21,8 @@ tokio-stream = "0.1"
|
||||||
toml = "0.5"
|
toml = "0.5"
|
||||||
tui = "0.19"
|
tui = "0.19"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
assert_matches = "1"
|
||||||
|
|
||||||
[target.'cfg(target_os="linux")'.dependencies]
|
[target.'cfg(target_os="linux")'.dependencies]
|
||||||
procfs = "0.14.1"
|
procfs = "0.14.1"
|
||||||
|
|
|
||||||
|
|
@ -232,7 +232,8 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Wait for the server's announcement.
|
// Wait for the server's announcement.
|
||||||
if let Message::Hello(major, minor, _) = reader.read().await? {
|
if let Message::Hello(major, minor, _) = reader.read().await? {
|
||||||
if major != 0 || minor > 1 {
|
info!("Server Version: {major} {minor}");
|
||||||
|
if major != 0 || minor > 2 {
|
||||||
bail!("Unsupported remote protocol version {}.{}", major, minor);
|
bail!("Unsupported remote protocol version {}.{}", major, minor);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -253,11 +254,13 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
}
|
}
|
||||||
} => {
|
} => {
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
|
print!("Error sending refreshes\n");
|
||||||
return Err(e.into());
|
return Err(e.into());
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
result = client_handle_messages(reader, events) => {
|
result = client_handle_messages(reader, events) => {
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
|
print!("Error handling messages\n");
|
||||||
return Err(e.into());
|
return Err(e.into());
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -391,3 +394,91 @@ pub async fn run_client(remote: &str) {
|
||||||
_ = client_connect_loop(remote, event_sender) => ()
|
_ = client_connect_loop(remote, event_sender) => ()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use tokio::io::DuplexStream;
|
||||||
|
use tokio::sync::mpsc::Receiver;
|
||||||
|
|
||||||
|
struct Fixture {
|
||||||
|
_server_read: MessageReader<DuplexStream>,
|
||||||
|
server_write: MessageWriter<DuplexStream>,
|
||||||
|
_event_receiver: Receiver<ui::UIEvent>,
|
||||||
|
client_result: Option<tokio::task::JoinHandle<anyhow::Result<()>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Fixture {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let (server_read, client_write) = tokio::io::duplex(4096);
|
||||||
|
let server_read = MessageReader::new(server_read);
|
||||||
|
let client_write = MessageWriter::new(client_write);
|
||||||
|
|
||||||
|
let (client_read, server_write) = tokio::io::duplex(4096);
|
||||||
|
let client_read = MessageReader::new(client_read);
|
||||||
|
let server_write = MessageWriter::new(server_write);
|
||||||
|
|
||||||
|
let (event_sender, event_receiver) = mpsc::channel(1024);
|
||||||
|
|
||||||
|
let client_result = tokio::spawn(async move {
|
||||||
|
let mut client_read = client_read;
|
||||||
|
let mut client_write = client_write;
|
||||||
|
client_main(
|
||||||
|
0,
|
||||||
|
&mut client_read,
|
||||||
|
&mut client_write,
|
||||||
|
event_sender,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
Fixture {
|
||||||
|
_server_read: server_read,
|
||||||
|
server_write,
|
||||||
|
_event_receiver: event_receiver,
|
||||||
|
client_result: Some(client_result),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn shutdown(mut self) -> anyhow::Result<()> {
|
||||||
|
let result = self.client_result.take();
|
||||||
|
drop(self); // Side effect: close all streams.
|
||||||
|
result.unwrap().await.expect("Unexpected join error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn basic_hello_sync() {
|
||||||
|
let mut t = Fixture::new();
|
||||||
|
|
||||||
|
t.server_write
|
||||||
|
.write(Message::Hello(0, 2, vec![]))
|
||||||
|
.await
|
||||||
|
.expect("Error sending hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn basic_hello_high_minor() {
|
||||||
|
let mut t = Fixture::new();
|
||||||
|
|
||||||
|
t.server_write
|
||||||
|
.write(Message::Hello(0, 99, vec![]))
|
||||||
|
.await
|
||||||
|
.expect("Error sending hello");
|
||||||
|
|
||||||
|
assert_matches!(t.shutdown().await, Err(_));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn basic_hello_wrong_major() {
|
||||||
|
let mut t = Fixture::new();
|
||||||
|
|
||||||
|
t.server_write
|
||||||
|
.write(Message::Hello(99, 0, vec![]))
|
||||||
|
.await
|
||||||
|
.expect("Error sending hello");
|
||||||
|
|
||||||
|
assert_matches!(t.shutdown().await, Err(_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,40 @@
|
||||||
use anyhow::Result;
|
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Messages
|
// Errors
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum MessageError {
|
pub enum Error {
|
||||||
#[error("Message type unknown: {0}")]
|
#[error("Message type unknown: {0}")]
|
||||||
Unknown(u8),
|
Unknown(u8),
|
||||||
#[error("Message incomplete")]
|
#[error("Message incomplete")]
|
||||||
Incomplete,
|
Incomplete,
|
||||||
|
#[error("String contained invalid UTF8: {0}")]
|
||||||
|
InvalidString(std::str::Utf8Error),
|
||||||
|
#[error("IO Error occurred: {0}")]
|
||||||
|
IO(std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<std::str::Utf8Error> for Error {
|
||||||
|
fn from(value: std::str::Utf8Error) -> Self {
|
||||||
|
Self::InvalidString(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for Error {
|
||||||
|
fn from(value: std::io::Error) -> Self {
|
||||||
|
Self::IO(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Messages
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
pub struct PortDesc {
|
pub struct PortDesc {
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
|
@ -23,10 +43,17 @@ pub struct PortDesc {
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
pub enum Message {
|
pub enum Message {
|
||||||
Ping, // Ignored on both sides, can be used to test connection.
|
// Ignored on both sides, can be used to test connection.
|
||||||
Hello(u8, u8, Vec<String>), // Server info announcement: major version, minor version, headers.
|
Ping,
|
||||||
Refresh, // Request to refresh list of ports from client.
|
|
||||||
Ports(Vec<PortDesc>), // List of available ports from server to client.
|
// Server info announcement: major version, minor version, headers.
|
||||||
|
Hello(u8, u8, Vec<String>),
|
||||||
|
|
||||||
|
// Request to refresh list of ports from client.
|
||||||
|
Refresh,
|
||||||
|
|
||||||
|
// List of available ports from server to client.
|
||||||
|
Ports(Vec<PortDesc>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
|
@ -46,7 +73,9 @@ impl Message {
|
||||||
result.put_u8(0x01);
|
result.put_u8(0x01);
|
||||||
result.put_u8(*major);
|
result.put_u8(*major);
|
||||||
result.put_u8(*minor);
|
result.put_u8(*minor);
|
||||||
result.put_u16(details.len().try_into().expect("Too many details"));
|
result.put_u16(
|
||||||
|
details.len().try_into().expect("Too many details"),
|
||||||
|
);
|
||||||
for detail in details {
|
for detail in details {
|
||||||
put_string(result, detail);
|
put_string(result, detail);
|
||||||
}
|
}
|
||||||
|
|
@ -62,7 +91,8 @@ impl Message {
|
||||||
result.put_u16(port.port);
|
result.put_u16(port.port);
|
||||||
|
|
||||||
// Port descriptions can be long, let's make sure they're not.
|
// Port descriptions can be long, let's make sure they're not.
|
||||||
let sliced = slice_up_to(&port.desc, u16::max_value().into());
|
let sliced =
|
||||||
|
slice_up_to(&port.desc, u16::max_value().into());
|
||||||
put_string(result, sliced);
|
put_string(result, sliced);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -94,28 +124,28 @@ impl Message {
|
||||||
}
|
}
|
||||||
Ok(Ports(ports))
|
Ok(Ports(ports))
|
||||||
}
|
}
|
||||||
b => Err(MessageError::Unknown(b).into()),
|
b => Err(Error::Unknown(b).into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8, MessageError> {
|
fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result<u8> {
|
||||||
if !cursor.has_remaining() {
|
if !cursor.has_remaining() {
|
||||||
return Err(MessageError::Incomplete);
|
return Err(Error::Incomplete);
|
||||||
}
|
}
|
||||||
Ok(cursor.get_u8())
|
Ok(cursor.get_u8())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16, MessageError> {
|
fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16> {
|
||||||
if cursor.remaining() < 2 {
|
if cursor.remaining() < 2 {
|
||||||
return Err(MessageError::Incomplete);
|
return Err(Error::Incomplete);
|
||||||
}
|
}
|
||||||
Ok(cursor.get_u16())
|
Ok(cursor.get_u16())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes, MessageError> {
|
fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result<Bytes> {
|
||||||
if cursor.remaining() < length {
|
if cursor.remaining() < length {
|
||||||
return Err(MessageError::Incomplete);
|
return Err(Error::Incomplete);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(cursor.copy_to_bytes(length))
|
Ok(cursor.copy_to_bytes(length))
|
||||||
|
|
@ -255,10 +285,7 @@ mod message_tests {
|
||||||
str.push_str(&char);
|
str.push_str(&char);
|
||||||
}
|
}
|
||||||
|
|
||||||
let msg = Ports(vec![PortDesc {
|
let msg = Ports(vec![PortDesc { port: 8080, desc: str }]);
|
||||||
port: 8080,
|
|
||||||
desc: str,
|
|
||||||
}]);
|
|
||||||
msg.encode();
|
msg.encode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,34 @@ use crate::message::{Message, MessageReader, MessageWriter};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::{error, warn};
|
use log::{error, warn};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
mod refresh;
|
mod refresh;
|
||||||
|
|
||||||
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
// We drive writes through an mpsc queue, because we not only handle requests
|
||||||
reader: &mut MessageReader<Reader>,
|
// and responses from the client (refresh ports and the like) but also need
|
||||||
|
// to asynchronously send messages to the client (open this URL, etc).
|
||||||
|
async fn write_driver<Writer: AsyncWrite + Unpin>(
|
||||||
|
messages: &mut mpsc::Receiver<Message>,
|
||||||
writer: &mut MessageWriter<Writer>,
|
writer: &mut MessageWriter<Writer>,
|
||||||
|
) -> () {
|
||||||
|
loop {
|
||||||
|
match messages.recv().await {
|
||||||
|
Some(m) => {
|
||||||
|
writer.write(m).await.expect("Failed to write the message")
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle messages that the client sends to us.
|
||||||
|
async fn server_loop<Reader: AsyncRead + Unpin>(
|
||||||
|
reader: &mut MessageReader<Reader>,
|
||||||
|
writer: &mut mpsc::Sender<Message>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// The first message we send must be an announcement.
|
// The first message we send must be an announcement.
|
||||||
writer.write(Message::Hello(0, 1, vec![])).await?;
|
writer.send(Message::Hello(0, 2, vec![])).await?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
use Message::*;
|
use Message::*;
|
||||||
|
|
@ -24,7 +43,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if let Err(e) = writer.write(Message::Ports(ports)).await {
|
if let Err(e) = writer.send(Message::Ports(ports)).await {
|
||||||
// Writer has been closed for some reason, we can just
|
// Writer has been closed for some reason, we can just
|
||||||
// quit.... I hope everything is OK?
|
// quit.... I hope everything is OK?
|
||||||
warn!("Warning: Error sending: {:?}", e);
|
warn!("Warning: Error sending: {:?}", e);
|
||||||
|
|
@ -35,24 +54,76 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_server() {
|
// Run the various server loops.
|
||||||
let reader = BufReader::new(tokio::io::stdin());
|
async fn server_main<
|
||||||
let mut writer = BufWriter::new(tokio::io::stdout());
|
In: AsyncRead + Unpin + Send,
|
||||||
|
Out: AsyncWrite + Unpin + Send,
|
||||||
|
>(
|
||||||
|
stdin: In,
|
||||||
|
stdout: Out,
|
||||||
|
) -> Result<()> {
|
||||||
|
let reader = BufReader::new(stdin);
|
||||||
|
let mut writer = BufWriter::new(stdout);
|
||||||
|
|
||||||
// Write the 8-byte synchronization marker.
|
// Write the 8-byte synchronization marker.
|
||||||
writer
|
writer
|
||||||
.write_u64(0x00_00_00_00_00_00_00_00)
|
.write_u64(0x00_00_00_00_00_00_00_00)
|
||||||
.await
|
.await
|
||||||
.expect("Error writing marker");
|
.expect("Error writing marker");
|
||||||
|
writer.flush().await.expect("Error flushing");
|
||||||
|
|
||||||
if let Err(e) = writer.flush().await {
|
let (mut sender, mut receiver) = mpsc::channel(10);
|
||||||
eprintln!("Error writing sync marker: {:?}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut writer = MessageWriter::new(writer);
|
let mut writer = MessageWriter::new(writer);
|
||||||
let mut reader = MessageReader::new(reader);
|
let mut reader = MessageReader::new(reader);
|
||||||
if let Err(e) = server_main(&mut reader, &mut writer).await {
|
|
||||||
eprintln!("Error: {:?}", e);
|
let (_, result) = tokio::join!(
|
||||||
|
write_driver(&mut receiver, &mut writer),
|
||||||
|
server_loop(&mut reader, &mut sender)
|
||||||
|
);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_server() {
|
||||||
|
let stdin = tokio::io::stdin();
|
||||||
|
let stdout = tokio::io::stdout();
|
||||||
|
if let Err(e) = server_main(stdin, stdout).await {
|
||||||
|
error!("Error: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||||
|
|
||||||
|
async fn sync(client_read: &mut DuplexStream) {
|
||||||
|
print!("[client] Waiting for server sync...\n");
|
||||||
|
for _ in 0..8 {
|
||||||
|
let b = client_read
|
||||||
|
.read_u8()
|
||||||
|
.await
|
||||||
|
.expect("Error reading sync byte");
|
||||||
|
assert_eq!(b, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut reader = MessageReader::new(client_read);
|
||||||
|
print!("[client] Reading first message...\n");
|
||||||
|
let msg = reader.read().await.expect("Error reading first message");
|
||||||
|
assert_matches!(msg, Message::Hello(0, 2, _));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn basic_hello_sync() {
|
||||||
|
let (server_read, _client_write) = tokio::io::duplex(4096);
|
||||||
|
let (mut client_read, server_write) = tokio::io::duplex(4096);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
server_main(server_read, server_write)
|
||||||
|
.await
|
||||||
|
.expect("Error in server!");
|
||||||
|
});
|
||||||
|
|
||||||
|
sync(&mut client_read).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
4
test.py
4
test.py
|
|
@ -1,5 +1,4 @@
|
||||||
#!/bin/env python3
|
#!/bin/env python3
|
||||||
import os
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
@ -14,6 +13,7 @@ def ssh(remote, *args):
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
local("cargo", "build")
|
local("cargo", "build")
|
||||||
|
local("cargo", "build", "--target=x86_64-unknown-linux-gnu")
|
||||||
|
|
||||||
remote = args[1]
|
remote = args[1]
|
||||||
print(f"Copying file to {remote}...")
|
print(f"Copying file to {remote}...")
|
||||||
|
|
@ -23,7 +23,7 @@ def main(args):
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Starting process...")
|
print("Starting process...")
|
||||||
subprocess.run(["target/debug/fwd", remote])
|
subprocess.run(["target/debug/fwd", remote])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue