Protocol version, async pump, start some testing

This commit is contained in:
John Doty 2022-12-16 13:57:52 -08:00
parent 763ecd190e
commit 6f906d80a7
6 changed files with 237 additions and 38 deletions

View file

@ -2,15 +2,34 @@ use crate::message::{Message, MessageReader, MessageWriter};
use anyhow::Result;
use log::{error, warn};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::mpsc;
mod refresh;
async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>,
// We drive writes through an mpsc queue, because we not only handle requests
// and responses from the client (refresh ports and the like) but also need
// to asynchronously send messages to the client (open this URL, etc).
async fn write_driver<Writer: AsyncWrite + Unpin>(
messages: &mut mpsc::Receiver<Message>,
writer: &mut MessageWriter<Writer>,
) -> () {
loop {
match messages.recv().await {
Some(m) => {
writer.write(m).await.expect("Failed to write the message")
}
None => break,
}
}
}
// Handle messages that the client sends to us.
async fn server_loop<Reader: AsyncRead + Unpin>(
reader: &mut MessageReader<Reader>,
writer: &mut mpsc::Sender<Message>,
) -> Result<()> {
// The first message we send must be an announcement.
writer.write(Message::Hello(0, 1, vec![])).await?;
writer.send(Message::Hello(0, 2, vec![])).await?;
loop {
use Message::*;
@ -24,7 +43,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
vec![]
}
};
if let Err(e) = writer.write(Message::Ports(ports)).await {
if let Err(e) = writer.send(Message::Ports(ports)).await {
// Writer has been closed for some reason, we can just
// quit.... I hope everything is OK?
warn!("Warning: Error sending: {:?}", e);
@ -35,24 +54,76 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
}
}
pub async fn run_server() {
let reader = BufReader::new(tokio::io::stdin());
let mut writer = BufWriter::new(tokio::io::stdout());
// Run the various server loops.
async fn server_main<
In: AsyncRead + Unpin + Send,
Out: AsyncWrite + Unpin + Send,
>(
stdin: In,
stdout: Out,
) -> Result<()> {
let reader = BufReader::new(stdin);
let mut writer = BufWriter::new(stdout);
// Write the 8-byte synchronization marker.
writer
.write_u64(0x00_00_00_00_00_00_00_00)
.await
.expect("Error writing marker");
writer.flush().await.expect("Error flushing");
if let Err(e) = writer.flush().await {
eprintln!("Error writing sync marker: {:?}", e);
return;
}
let (mut sender, mut receiver) = mpsc::channel(10);
let mut writer = MessageWriter::new(writer);
let mut reader = MessageReader::new(reader);
if let Err(e) = server_main(&mut reader, &mut writer).await {
eprintln!("Error: {:?}", e);
let (_, result) = tokio::join!(
write_driver(&mut receiver, &mut writer),
server_loop(&mut reader, &mut sender)
);
result
}
pub async fn run_server() {
let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout();
if let Err(e) = server_main(stdin, stdout).await {
error!("Error: {:?}", e);
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use tokio::io::{AsyncReadExt, DuplexStream};
async fn sync(client_read: &mut DuplexStream) {
print!("[client] Waiting for server sync...\n");
for _ in 0..8 {
let b = client_read
.read_u8()
.await
.expect("Error reading sync byte");
assert_eq!(b, 0);
}
let mut reader = MessageReader::new(client_read);
print!("[client] Reading first message...\n");
let msg = reader.read().await.expect("Error reading first message");
assert_matches!(msg, Message::Hello(0, 2, _));
}
#[tokio::test]
async fn basic_hello_sync() {
let (server_read, _client_write) = tokio::io::duplex(4096);
let (mut client_read, server_write) = tokio::io::duplex(4096);
tokio::spawn(async move {
server_main(server_read, server_write)
.await
.expect("Error in server!");
});
sync(&mut client_read).await;
}
}