From c5bf78fc715ea05b86265f32c6214d0488149999 Mon Sep 17 00:00:00 2001 From: John Doty Date: Fri, 7 Oct 2022 13:56:59 +0000 Subject: [PATCH] It's alive --- src/main.rs | 318 +++++++++++++++++++++++++++++++++++++++---------- src/message.rs | 34 +++--- 2 files changed, 271 insertions(+), 81 deletions(-) diff --git a/src/main.rs b/src/main.rs index fd4b740..0d82cca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::io::Cursor; use std::net::{Ipv4Addr, SocketAddrV4}; @@ -26,6 +26,7 @@ impl MessageWriter { } async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { // TODO: Optimize buffer usage please this is bad + eprintln!("? {:?}", msg); let mut buffer = msg.encode(); self.writer .write_u32(buffer.len().try_into().expect("Message too large")) @@ -46,25 +47,106 @@ async fn pump_write( Ok(()) } -async fn server_connect( - channel: u64, - port: u16, - writer: &mut mpsc::Sender, -) -> Result<(), Error> { - let _stream = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; - if let Err(e) = writer.send(Message::Connected(channel, port)).await { - eprintln!("< Warning: couldn't send Connected: {:?}", e); - return Err(Error::from(ErrorKind::BrokenPipe)); +struct ServerConnection { + close: Option>, + data: mpsc::Sender, +} + +#[derive(Clone)] +struct ServerConnectionTable { + connections: Arc>>, +} + +impl ServerConnectionTable { + fn new() -> ServerConnectionTable { + ServerConnectionTable { + connections: Arc::new(Mutex::new(HashMap::new())), + } } - // Do the thing, read and write and whatnot. + fn add(self: &mut Self, id: u64, close: oneshot::Sender<()>, data: mpsc::Sender) { + let mut connections = self.connections.lock().unwrap(); + connections.insert( + id, + ServerConnection { + close: Some(close), + data, + }, + ); + } + fn close(self: &mut Self, id: u64) { + let mut connections = self.connections.lock().unwrap(); + if let Some(connection) = connections.get_mut(&id) { + if let Some(close) = connection.close.take() { + _ = close.send(()); + } + } + } + + async fn receive(self: &Self, id: u64, buf: Bytes) { + let data = { + let connections = self.connections.lock().unwrap(); + if let Some(connection) = connections.get(&id) { + Some(connection.data.clone()) + } else { + None + } + }; + + if let Some(data) = data { + _ = data.send(buf).await; + } + } + + fn remove(self: &mut Self, id: u64) { + let mut connections = self.connections.lock().unwrap(); + connections.remove(&id); + } +} + +async fn server_connection_write( + data: &mut mpsc::Receiver, + write: &mut T, +) -> Result<(), Error> { + while let Some(buf) = data.recv().await { + write.write_all(&buf[..]).await?; + } Ok(()) } +async fn server_handle_connection( + channel: u64, + port: u16, + writer: mpsc::Sender, + connections: ServerConnectionTable, +) { + let mut connections = connections; + if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { + let (send_close, closed) = oneshot::channel(); + let (send_data, mut data) = mpsc::channel(32); + connections.add(channel, send_close, send_data); + if let Ok(_) = writer.send(Message::Connected(channel)).await { + let (mut read_half, mut write_half) = stream.split(); + + // TODO: Read until we get a close on `rx`. + + tokio::select! { + _ = client_connection_read(channel, &mut read_half, writer.clone()) => (), + _ = server_connection_write(&mut data, &mut write_half) => (), + _ = closed => (), + } + } + connections.remove(channel); + } + + _ = writer.send(Message::Closed(channel)); +} + async fn server_read( reader: &mut T, writer: mpsc::Sender, + connections: ServerConnectionTable, ) -> Result<(), Error> { eprintln!("< Processing packets..."); loop { @@ -83,13 +165,21 @@ async fn server_read( match message { Ping => (), Connect(channel, port) => { - let writer = writer.clone(); + let (writer, connections) = (writer.clone(), connections.clone()); tokio::spawn(async move { - let mut writer = writer; - if let Err(e) = server_connect(channel, port, &mut writer).await { - eprintln!("< Connection failed: {:?}", e); - _ = writer.send(Message::Abort(channel)).await; - } + server_handle_connection(channel, port, writer, connections).await; + }); + } + Close(channel) => { + let mut connections = connections.clone(); + tokio::spawn(async move { + connections.close(channel); + }); + } + Data(channel, buf) => { + let connections = connections.clone(); + tokio::spawn(async move { + connections.receive(channel, buf).await; }); } Refresh => { @@ -117,10 +207,12 @@ async fn server_main( reader: &mut Reader, writer: &mut MessageWriter, ) -> Result<(), Error> { + let connections = ServerConnectionTable::new(); + // Jump into it... let (msg_sender, mut msg_receiver) = mpsc::channel(32); let writing = pump_write(&mut msg_receiver, writer); - let reading = server_read(reader, msg_sender); + let reading = server_read(reader, msg_sender, connections); tokio::pin!(reading); tokio::pin!(writing); @@ -169,76 +261,173 @@ async fn client_sync(reader: &mut T) -> Result<(), Error> } struct ClientConnection { - connected: Option>>, + connected: Option>, + closed: Option>, + data: mpsc::Sender, } -struct ClientConnectionTable { +struct ClientConnectionTableState { next_id: u64, connections: HashMap, } +#[derive(Clone)] +struct ClientConnectionTable { + connections: Arc>, +} + impl ClientConnectionTable { fn new() -> ClientConnectionTable { ClientConnectionTable { - next_id: 0, - connections: HashMap::new(), + connections: Arc::new(Mutex::new(ClientConnectionTableState { + next_id: 0, + connections: HashMap::new(), + })), } } -} -type ClientConnections = Arc>; - -async fn client_handle_connection( - port: u16, - writer: mpsc::Sender, - connections: ClientConnections, -) { - let (connected, rx) = oneshot::channel(); - let channel_id = { - let mut tbl = connections.lock().unwrap(); + fn alloc( + self: &mut Self, + connected: oneshot::Sender<()>, + closed: oneshot::Sender<()>, + data: mpsc::Sender, + ) -> u64 { + let mut tbl = self.connections.lock().unwrap(); let id = tbl.next_id; tbl.next_id += 1; tbl.connections.insert( id, ClientConnection { connected: Some(connected), + closed: Some(closed), + data, }, ); id - }; + } - if let Ok(_) = writer.send(Message::Connect(channel_id, port)).await { - if let Ok(r) = rx.await { - if let Ok(_) = r { - // Connection worked! Do the damn thing. - eprintln!("Got here I guess! {}", channel_id); + fn closed(self: &mut Self, id: u64) { + let closed = { + let mut tbl = self.connections.lock().unwrap(); + if let Some(c) = tbl.connections.get_mut(&id) { + c.closed.take() + } else { + None } + }; + + if let Some(closed) = closed { + _ = closed.send(()); } } - { - let mut tbl = connections.lock().unwrap(); - tbl.connections.remove(&channel_id); + fn connected(self: &mut Self, id: u64) { + let connected = { + let mut tbl = self.connections.lock().unwrap(); + if let Some(c) = tbl.connections.get_mut(&id) { + c.connected.take() + } else { + None + } + }; + + if let Some(connected) = connected { + _ = connected.send(()); + } } - // If the writer is closed then the whole connection is closed. + + async fn receive(self: &Self, id: u64, buf: Bytes) { + let data = { + let tbl = self.connections.lock().unwrap(); + if let Some(connection) = tbl.connections.get(&id) { + Some(connection.data.clone()) + } else { + None + } + }; + + if let Some(data) = data { + _ = data.send(buf).await; + } + } + + fn remove(self: &mut Self, id: u64) { + let mut tbl = self.connections.lock().unwrap(); + tbl.connections.remove(&id); + } +} + +async fn client_connection_read( + channel: u64, + read: &mut T, + writer: mpsc::Sender, +) -> Result<(), Error> { + loop { + let mut buffer = BytesMut::with_capacity(64 * 1024); + read.read_buf(&mut buffer).await?; + if buffer.len() == 0 { + return Ok(()); + } + + if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { + return Err(Error::from(ErrorKind::ConnectionReset)); + } + + // TODO: Flow control here, wait for the packet to be acknowleged so + // there isn't head-of-line blocking or infinite bufferingon the + // remote side. Also buffer re-use! + } +} + +async fn client_handle_connection( + port: u16, + writer: mpsc::Sender, + connections: ClientConnectionTable, + socket: &mut TcpStream, +) { + let mut connections = connections; + let (send_connected, connected) = oneshot::channel(); + let (send_closed, mut closed) = oneshot::channel(); + let (send_data, mut data) = mpsc::channel(32); + let channel_id = connections.alloc(send_connected, send_closed, send_data); + + if let Ok(_) = writer.send(Message::Connect(channel_id, port)).await { + let connected = tokio::select! { + _ = connected => true, + _ = &mut closed => false + }; + + if connected { + let (mut read_half, mut write_half) = socket.split(); + tokio::select! { + _ = client_connection_read(channel_id, &mut read_half, writer.clone()) => (), + _ = server_connection_write(&mut data, &mut write_half) => (), + _ = closed => () + }; + } else { + eprintln!("> Failed to connect to remote"); + } + } + + connections.remove(channel_id); _ = writer.send(Message::Close(channel_id)).await; } async fn client_listen( port: u16, writer: mpsc::Sender, - connections: ClientConnections, + connections: ClientConnectionTable, ) -> Result<(), Error> { loop { let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; loop { // The second item contains the IP and port of the new connection. // TODO: Handle shutdown correctly. - let (_, _) = listener.accept().await?; + let (mut socket, _) = listener.accept().await?; let (writer, connections) = (writer.clone(), connections.clone()); tokio::spawn(async move { - client_handle_connection(port, writer, connections).await; + client_handle_connection(port, writer, connections, &mut socket).await; }); } } @@ -246,8 +435,8 @@ async fn client_listen( async fn client_read( reader: &mut T, - connections: ClientConnections, writer: mpsc::Sender, + connections: ClientConnectionTable, ) -> Result<(), Error> { let mut listeners: HashMap> = HashMap::new(); @@ -267,20 +456,23 @@ async fn client_read( use Message::*; match message { Ping => (), - Connected(channel, _) => { - let connected = { - let mut tbl = connections.lock().unwrap(); - if let Some(c) = tbl.connections.get_mut(&channel) { - c.connected.take() - } else { - None - } - }; - - if let Some(connected) = connected { - // If we can't send the notification then... uh... ok? - _ = connected.send(Ok(())); - } + Connected(channel) => { + let mut connections = connections.clone(); + tokio::spawn(async move { + connections.connected(channel); + }); + } + Close(channel) => { + let mut connections = connections.clone(); + tokio::spawn(async move { + connections.closed(channel); + }); + } + Data(channel, buf) => { + let connections = connections.clone(); + tokio::spawn(async move { + connections.receive(channel, buf).await; + }); } Ports(ports) => { let mut new_listeners = HashMap::new(); @@ -333,12 +525,12 @@ async fn client_main( eprintln!("> Sending initial list command..."); writer.write(Message::Refresh).await?; - let connections = Arc::new(Mutex::new(ClientConnectionTable::new())); + let connections = ClientConnectionTable::new(); // And now really get into it... let (msg_sender, mut msg_receiver) = mpsc::channel(32); let writing = pump_write(&mut msg_receiver, writer); - let reading = client_read(reader, connections, msg_sender); + let reading = client_read(reader, msg_sender, connections); tokio::pin!(reading); tokio::pin!(writing); diff --git a/src/message.rs b/src/message.rs index f0f5801..95be271 100644 --- a/src/message.rs +++ b/src/message.rs @@ -17,10 +17,10 @@ pub struct PortDesc { #[derive(Debug, PartialEq)] pub enum Message { Ping, - Connect(u64, u16), // Request to connect on a port from client to server. - Connected(u64, u16), // Sucessfully connected from server to client. - Close(u64), // Request to close from client to server. - Abort(u64), // Notify of close from server to client. + Connect(u64, u16), // Request to connect on a port from client to server. + Connected(u64), // Sucessfully connected from server to client. + Close(u64), // Request to close connection on either end. + // Abort(u64), // Notify of close from server to client. Closed(u64), // Response to Close or Abort. Refresh, // Request to refresh list of ports from client. Ports(Vec), // List of available ports from server to client. @@ -40,19 +40,18 @@ impl Message { result.put_u64(*channel); result.put_u16(*port); } - Connected(channel, port) => { + Connected(channel) => { result.put_u8(0x02); result.put_u64(*channel); - result.put_u16(*port); } Close(channel) => { result.put_u8(0x03); result.put_u64(*channel); } - Abort(channel) => { - result.put_u8(0x04); - result.put_u64(*channel); - } + // Abort(channel) => { + // result.put_u8(0x04); + // result.put_u64(*channel); + // } Closed(channel) => { result.put_u8(0x05); result.put_u64(*channel); @@ -93,17 +92,16 @@ impl Message { } 0x02 => { let channel = get_u64(cursor)?; - let port = get_u16(cursor)?; - Ok(Connected(channel, port)) + Ok(Connected(channel)) } 0x03 => { let channel = get_u64(cursor)?; Ok(Close(channel)) } - 0x04 => { - let channel = get_u64(cursor)?; - Ok(Abort(channel)) - } + // 0x04 => { + // let channel = get_u64(cursor)?; + // Ok(Abort(channel)) + // } 0x05 => { let channel = get_u64(cursor)?; Ok(Closed(channel)) @@ -155,9 +153,9 @@ mod message_tests { fn round_trip() { assert_round_trip(Ping); assert_round_trip(Connect(0x1234567890123456, 0x1234)); - assert_round_trip(Connected(0x1234567890123456, 0x1234)); + assert_round_trip(Connected(0x1234567890123456)); assert_round_trip(Close(0x1234567890123456)); - assert_round_trip(Abort(0x1234567890123456)); + // assert_round_trip(Abort(0x1234567890123456)); assert_round_trip(Closed(0x1234567890123456)); assert_round_trip(Refresh); assert_round_trip(Ports(vec![