Doc doc
This commit is contained in:
parent
2faed6267e
commit
b98d28bd90
2 changed files with 146 additions and 89 deletions
|
|
@ -1,9 +1,12 @@
|
||||||
use crate::message::Message;
|
use crate::message::Message;
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
const MAX_PACKET: usize = u16::max_value() as usize;
|
const MAX_PACKET: usize = u16::max_value() as usize;
|
||||||
|
|
||||||
|
|
@ -42,7 +45,7 @@ async fn connection_read<T: AsyncRead + Unpin>(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Flow control here, wait for the packet to be acknowleged so
|
// TODO: Flow control here, wait for the packet to be acknowleged so
|
||||||
// there isn't head-of-line blocking or infinite bufferingon the
|
// there isn't head-of-line blocking or infinite buffering on the
|
||||||
// remote side. Also buffer re-use!
|
// remote side. Also buffer re-use!
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -100,3 +103,120 @@ pub async fn process(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The connection structure tracks the various channels used to communicate
|
||||||
|
/// with an "open" connection.
|
||||||
|
struct Connection {
|
||||||
|
/// The callback for the connected message, if we haven't already
|
||||||
|
/// connected across the channel. Realistically, this only ever has a
|
||||||
|
/// value on the client side, where we wait for the server side to
|
||||||
|
/// connect and then acknowlege that the connection.
|
||||||
|
connected: Option<oneshot::Sender<()>>,
|
||||||
|
|
||||||
|
/// The channel where the connection receives [Bytes] to be written to
|
||||||
|
/// the socket.
|
||||||
|
data: mpsc::Sender<Bytes>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ConnectionTableState {
|
||||||
|
next_id: u64,
|
||||||
|
connections: HashMap<u64, Connection>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tracking structure for connections. This structure is thread-safe and
|
||||||
|
/// so can be used to track new connections from as many concurrent listeners
|
||||||
|
/// as you would like.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ConnectionTable {
|
||||||
|
connections: Arc<Mutex<ConnectionTableState>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionTable {
|
||||||
|
/// Create a new, empty connection table.
|
||||||
|
pub fn new() -> ConnectionTable {
|
||||||
|
ConnectionTable {
|
||||||
|
connections: Arc::new(Mutex::new(ConnectionTableState {
|
||||||
|
next_id: 0,
|
||||||
|
connections: HashMap::new(),
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocate a new connection on the client side. The connection is
|
||||||
|
/// assigned a new ID, which is returned to the caller.
|
||||||
|
pub fn alloc(
|
||||||
|
self: &mut Self,
|
||||||
|
connected: oneshot::Sender<()>,
|
||||||
|
data: mpsc::Sender<Bytes>,
|
||||||
|
) -> u64 {
|
||||||
|
let mut tbl = self.connections.lock().unwrap();
|
||||||
|
let id = tbl.next_id;
|
||||||
|
tbl.next_id += 1;
|
||||||
|
tbl.connections.insert(
|
||||||
|
id,
|
||||||
|
Connection {
|
||||||
|
connected: Some(connected),
|
||||||
|
data,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a connection to the table on the server side. The client sent us
|
||||||
|
/// the ID to use, so we don't need to allocate it, and obviously we
|
||||||
|
/// aren't going to be waiting for the connection to be "connected."
|
||||||
|
pub fn add(self: &mut Self, id: u64, data: mpsc::Sender<Bytes>) {
|
||||||
|
let mut tbl = self.connections.lock().unwrap();
|
||||||
|
tbl.connections.insert(
|
||||||
|
id,
|
||||||
|
Connection {
|
||||||
|
connected: None,
|
||||||
|
data,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark a connection as being "connected", on the client side, where we
|
||||||
|
/// wait for the server to tell us such things. Note that this gets used
|
||||||
|
/// for a successful connection; on a failure just call [remove].
|
||||||
|
pub fn connected(self: &mut Self, id: u64) {
|
||||||
|
let connected = {
|
||||||
|
let mut tbl = self.connections.lock().unwrap();
|
||||||
|
if let Some(c) = tbl.connections.get_mut(&id) {
|
||||||
|
c.connected.take()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(connected) = connected {
|
||||||
|
_ = connected.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell a connection that we have received data. This gets used on both
|
||||||
|
/// sides of the pipe; if the connection exists and is still active it
|
||||||
|
/// will send the data out through its socket.
|
||||||
|
pub async fn receive(self: &Self, id: u64, buf: Bytes) {
|
||||||
|
let data = {
|
||||||
|
let tbl = self.connections.lock().unwrap();
|
||||||
|
if let Some(connection) = tbl.connections.get(&id) {
|
||||||
|
Some(connection.data.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(data) = data {
|
||||||
|
_ = data.send(buf).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a connection from the table, effectively closing it. This will
|
||||||
|
/// close all the pipes that the connection uses to receive data from the
|
||||||
|
/// other side, performing a cleanup on our "write" side of the socket.
|
||||||
|
pub fn remove(self: &mut Self, id: u64) {
|
||||||
|
let mut tbl = self.connections.lock().unwrap();
|
||||||
|
tbl.connections.remove(&id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
113
src/lib.rs
113
src/lib.rs
|
|
@ -1,8 +1,7 @@
|
||||||
use bytes::Bytes;
|
use connection::ConnectionTable;
|
||||||
use message::{Message, MessageReader, MessageWriter};
|
use message::{Message, MessageReader, MessageWriter};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::{Ipv4Addr, SocketAddrV4};
|
use std::net::{Ipv4Addr, SocketAddrV4};
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::process;
|
use tokio::process;
|
||||||
|
|
@ -70,6 +69,26 @@ impl PartialEq for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Write Management
|
||||||
|
|
||||||
|
/// Gathers writes from an mpsc queue and writes them to the specified
|
||||||
|
/// writer.
|
||||||
|
///
|
||||||
|
/// This is kind of an odd function. It raises a lot of questions.
|
||||||
|
///
|
||||||
|
/// *Why can't this just be a wrapper function on top of MessageWriter that
|
||||||
|
/// everybody calls?* Well, we could do that, but we also need to synchronize
|
||||||
|
/// writes to the underlying stream.
|
||||||
|
///
|
||||||
|
/// *Why not use an async mutex?* Because this function has a nice side
|
||||||
|
/// benefit: if it ever quits, we're *either* doing an orderly shutdown
|
||||||
|
/// (because the last write end of this channel closed) *or* the remote
|
||||||
|
/// connection has closed. [client_main] uses this fact to its advantage to
|
||||||
|
/// detect when the connection has failed.
|
||||||
|
///
|
||||||
|
/// At some point we may even automatically reconnect in response!
|
||||||
|
///
|
||||||
async fn pump_write<T: AsyncWrite + Unpin>(
|
async fn pump_write<T: AsyncWrite + Unpin>(
|
||||||
messages: &mut mpsc::Receiver<Message>,
|
messages: &mut mpsc::Receiver<Message>,
|
||||||
writer: &mut MessageWriter<T>,
|
writer: &mut MessageWriter<T>,
|
||||||
|
|
@ -83,92 +102,6 @@ async fn pump_write<T: AsyncWrite + Unpin>(
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Server
|
// Server
|
||||||
|
|
||||||
struct Connection {
|
|
||||||
connected: Option<oneshot::Sender<()>>,
|
|
||||||
data: mpsc::Sender<Bytes>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ConnectionTableState {
|
|
||||||
next_id: u64,
|
|
||||||
connections: HashMap<u64, Connection>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ConnectionTable {
|
|
||||||
connections: Arc<Mutex<ConnectionTableState>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionTable {
|
|
||||||
fn new() -> ConnectionTable {
|
|
||||||
ConnectionTable {
|
|
||||||
connections: Arc::new(Mutex::new(ConnectionTableState {
|
|
||||||
next_id: 0,
|
|
||||||
connections: HashMap::new(),
|
|
||||||
})),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender<Bytes>) -> u64 {
|
|
||||||
let mut tbl = self.connections.lock().unwrap();
|
|
||||||
let id = tbl.next_id;
|
|
||||||
tbl.next_id += 1;
|
|
||||||
tbl.connections.insert(
|
|
||||||
id,
|
|
||||||
Connection {
|
|
||||||
connected: Some(connected),
|
|
||||||
data,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add(self: &mut Self, id: u64, data: mpsc::Sender<Bytes>) {
|
|
||||||
let mut tbl = self.connections.lock().unwrap();
|
|
||||||
tbl.connections.insert(
|
|
||||||
id,
|
|
||||||
Connection {
|
|
||||||
connected: None,
|
|
||||||
data,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn connected(self: &mut Self, id: u64) {
|
|
||||||
let connected = {
|
|
||||||
let mut tbl = self.connections.lock().unwrap();
|
|
||||||
if let Some(c) = tbl.connections.get_mut(&id) {
|
|
||||||
c.connected.take()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(connected) = connected {
|
|
||||||
_ = connected.send(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn receive(self: &Self, id: u64, buf: Bytes) {
|
|
||||||
let data = {
|
|
||||||
let tbl = self.connections.lock().unwrap();
|
|
||||||
if let Some(connection) = tbl.connections.get(&id) {
|
|
||||||
Some(connection.data.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(data) = data {
|
|
||||||
_ = data.send(buf).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn remove(self: &mut Self, id: u64) {
|
|
||||||
let mut tbl = self.connections.lock().unwrap();
|
|
||||||
tbl.connections.remove(&id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn server_handle_connection(
|
async fn server_handle_connection(
|
||||||
channel: u64,
|
channel: u64,
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
@ -285,6 +218,10 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> {
|
async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error> {
|
||||||
|
// TODO: While we're waiting here we should be echoing everything we read.
|
||||||
|
// We should also be proxying *our* stdin to the processes stdin,
|
||||||
|
// and turn that off when we've synchronized. That way we can
|
||||||
|
// handle passwords and the like for authentication.
|
||||||
eprintln!("> Waiting for synchronization marker...");
|
eprintln!("> Waiting for synchronization marker...");
|
||||||
let mut seen = 0;
|
let mut seen = 0;
|
||||||
while seen < 8 {
|
while seen < 8 {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue