From b98d28bd9068fb3a3c7ffe9c09884ba80ac3e51e Mon Sep 17 00:00:00 2001 From: John Doty Date: Sat, 8 Oct 2022 21:56:30 -0700 Subject: [PATCH] Doc doc --- src/connection.rs | 122 +++++++++++++++++++++++++++++++++++++++++++++- src/lib.rs | 113 ++++++++++-------------------------------- 2 files changed, 146 insertions(+), 89 deletions(-) diff --git a/src/connection.rs b/src/connection.rs index 87d62a4..2288d2b 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,9 +1,12 @@ use crate::message::Message; use crate::Error; use bytes::{Bytes, BytesMut}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::sync::mpsc; +use tokio::sync::oneshot; const MAX_PACKET: usize = u16::max_value() as usize; @@ -42,7 +45,7 @@ async fn connection_read( } // TODO: Flow control here, wait for the packet to be acknowleged so - // there isn't head-of-line blocking or infinite bufferingon the + // there isn't head-of-line blocking or infinite buffering on the // remote side. Also buffer re-use! }; @@ -100,3 +103,120 @@ pub async fn process( } } } + +/// The connection structure tracks the various channels used to communicate +/// with an "open" connection. +struct Connection { + /// The callback for the connected message, if we haven't already + /// connected across the channel. Realistically, this only ever has a + /// value on the client side, where we wait for the server side to + /// connect and then acknowlege that the connection. + connected: Option>, + + /// The channel where the connection receives [Bytes] to be written to + /// the socket. + data: mpsc::Sender, +} + +struct ConnectionTableState { + next_id: u64, + connections: HashMap, +} + +/// A tracking structure for connections. This structure is thread-safe and +/// so can be used to track new connections from as many concurrent listeners +/// as you would like. +#[derive(Clone)] +pub struct ConnectionTable { + connections: Arc>, +} + +impl ConnectionTable { + /// Create a new, empty connection table. + pub fn new() -> ConnectionTable { + ConnectionTable { + connections: Arc::new(Mutex::new(ConnectionTableState { + next_id: 0, + connections: HashMap::new(), + })), + } + } + + /// Allocate a new connection on the client side. The connection is + /// assigned a new ID, which is returned to the caller. + pub fn alloc( + self: &mut Self, + connected: oneshot::Sender<()>, + data: mpsc::Sender, + ) -> u64 { + let mut tbl = self.connections.lock().unwrap(); + let id = tbl.next_id; + tbl.next_id += 1; + tbl.connections.insert( + id, + Connection { + connected: Some(connected), + data, + }, + ); + id + } + + /// Add a connection to the table on the server side. The client sent us + /// the ID to use, so we don't need to allocate it, and obviously we + /// aren't going to be waiting for the connection to be "connected." + pub fn add(self: &mut Self, id: u64, data: mpsc::Sender) { + let mut tbl = self.connections.lock().unwrap(); + tbl.connections.insert( + id, + Connection { + connected: None, + data, + }, + ); + } + + /// Mark a connection as being "connected", on the client side, where we + /// wait for the server to tell us such things. Note that this gets used + /// for a successful connection; on a failure just call [remove]. + pub fn connected(self: &mut Self, id: u64) { + let connected = { + let mut tbl = self.connections.lock().unwrap(); + if let Some(c) = tbl.connections.get_mut(&id) { + c.connected.take() + } else { + None + } + }; + + if let Some(connected) = connected { + _ = connected.send(()); + } + } + + /// Tell a connection that we have received data. This gets used on both + /// sides of the pipe; if the connection exists and is still active it + /// will send the data out through its socket. + pub async fn receive(self: &Self, id: u64, buf: Bytes) { + let data = { + let tbl = self.connections.lock().unwrap(); + if let Some(connection) = tbl.connections.get(&id) { + Some(connection.data.clone()) + } else { + None + } + }; + + if let Some(data) = data { + _ = data.send(buf).await; + } + } + + /// Remove a connection from the table, effectively closing it. This will + /// close all the pipes that the connection uses to receive data from the + /// other side, performing a cleanup on our "write" side of the socket. + pub fn remove(self: &mut Self, id: u64) { + let mut tbl = self.connections.lock().unwrap(); + tbl.connections.remove(&id); + } +} diff --git a/src/lib.rs b/src/lib.rs index c82c149..aaa61df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,7 @@ -use bytes::Bytes; +use connection::ConnectionTable; use message::{Message, MessageReader, MessageWriter}; use std::collections::HashMap; use std::net::{Ipv4Addr, SocketAddrV4}; -use std::sync::{Arc, Mutex}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::process; @@ -70,6 +69,26 @@ impl PartialEq for Error { } } +// ---------------------------------------------------------------------------- +// Write Management + +/// Gathers writes from an mpsc queue and writes them to the specified +/// writer. +/// +/// This is kind of an odd function. It raises a lot of questions. +/// +/// *Why can't this just be a wrapper function on top of MessageWriter that +/// everybody calls?* Well, we could do that, but we also need to synchronize +/// writes to the underlying stream. +/// +/// *Why not use an async mutex?* Because this function has a nice side +/// benefit: if it ever quits, we're *either* doing an orderly shutdown +/// (because the last write end of this channel closed) *or* the remote +/// connection has closed. [client_main] uses this fact to its advantage to +/// detect when the connection has failed. +/// +/// At some point we may even automatically reconnect in response! +/// async fn pump_write( messages: &mut mpsc::Receiver, writer: &mut MessageWriter, @@ -83,92 +102,6 @@ async fn pump_write( // ---------------------------------------------------------------------------- // Server -struct Connection { - connected: Option>, - data: mpsc::Sender, -} - -struct ConnectionTableState { - next_id: u64, - connections: HashMap, -} - -#[derive(Clone)] -struct ConnectionTable { - connections: Arc>, -} - -impl ConnectionTable { - fn new() -> ConnectionTable { - ConnectionTable { - connections: Arc::new(Mutex::new(ConnectionTableState { - next_id: 0, - connections: HashMap::new(), - })), - } - } - - fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender) -> u64 { - let mut tbl = self.connections.lock().unwrap(); - let id = tbl.next_id; - tbl.next_id += 1; - tbl.connections.insert( - id, - Connection { - connected: Some(connected), - data, - }, - ); - id - } - - fn add(self: &mut Self, id: u64, data: mpsc::Sender) { - let mut tbl = self.connections.lock().unwrap(); - tbl.connections.insert( - id, - Connection { - connected: None, - data, - }, - ); - } - - fn connected(self: &mut Self, id: u64) { - let connected = { - let mut tbl = self.connections.lock().unwrap(); - if let Some(c) = tbl.connections.get_mut(&id) { - c.connected.take() - } else { - None - } - }; - - if let Some(connected) = connected { - _ = connected.send(()); - } - } - - async fn receive(self: &Self, id: u64, buf: Bytes) { - let data = { - let tbl = self.connections.lock().unwrap(); - if let Some(connection) = tbl.connections.get(&id) { - Some(connection.data.clone()) - } else { - None - } - }; - - if let Some(data) = data { - _ = data.send(buf).await; - } - } - - fn remove(self: &mut Self, id: u64) { - let mut tbl = self.connections.lock().unwrap(); - tbl.connections.remove(&id); - } -} - async fn server_handle_connection( channel: u64, port: u16, @@ -285,6 +218,10 @@ async fn server_main( } async fn client_sync(reader: &mut T) -> Result<(), Error> { + // TODO: While we're waiting here we should be echoing everything we read. + // We should also be proxying *our* stdin to the processes stdin, + // and turn that off when we've synchronized. That way we can + // handle passwords and the like for authentication. eprintln!("> Waiting for synchronization marker..."); let mut seen = 0; while seen < 8 {