diff --git a/src/main.rs b/src/main.rs index 0d82cca..76c5ccb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,9 @@ mod refresh; use message::Message; +// ---------------------------------------------------------------------------- +// Message Writing + struct MessageWriter { writer: T, } @@ -26,7 +29,7 @@ impl MessageWriter { } async fn write(self: &mut Self, msg: Message) -> Result<(), Error> { // TODO: Optimize buffer usage please this is bad - eprintln!("? {:?}", msg); + // eprintln!("? {:?}", msg); let mut buffer = msg.encode(); self.writer .write_u32(buffer.len().try_into().expect("Message too large")) @@ -47,8 +50,81 @@ async fn pump_write( Ok(()) } +// ---------------------------------------------------------------------------- +// Connection + +/// Read from a socket and convert the reads into Messages to put into the +/// queue until the socket is closed for reading or an error occurs. +async fn connection_read( + channel: u64, + read: &mut T, + writer: &mut mpsc::Sender, +) -> Result<(), Error> { + let result = loop { + let mut buffer = BytesMut::with_capacity(64 * 1024); + if let Err(e) = read.read_buf(&mut buffer).await { + break Err(e); + } + if buffer.len() == 0 { + break Ok(()); + } + + if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { + break Err(Error::from(ErrorKind::ConnectionReset)); + } + + // TODO: Flow control here, wait for the packet to be acknowleged so + // there isn't head-of-line blocking or infinite bufferingon the + // remote side. Also buffer re-use! + }; + + // We are effectively closed on this side, send the close to drop the + // corresponding write side on the other end of the pipe. + _ = writer.send(Message::Close(channel)).await; + return result; +} + +/// Get messages from a queue and write them out to a socket until there are +/// no more messages in the queue or the write breaks for some reason. +async fn connection_write( + data: &mut mpsc::Receiver, + write: &mut T, +) -> Result<(), Error> { + while let Some(buf) = data.recv().await { + write.write_all(&buf[..]).await?; + } + Ok(()) +} + +/// Handle a connection, from the socket to the multiplexer and from the +/// multiplexer to the socket. +async fn connection_process( + channel: u64, + stream: &mut TcpStream, + data: &mut mpsc::Receiver, + writer: &mut mpsc::Sender, +) { + let (mut read_half, mut write_half) = stream.split(); + + let read = connection_read(channel, &mut read_half, writer); + let write = connection_write(data, &mut write_half); + + tokio::pin!(read); + tokio::pin!(write); + + let (mut done_reading, mut done_writing) = (false, false); + while !(done_reading && done_writing) { + tokio::select! { + _ = &mut read, if !done_reading => { done_reading = true; }, + _ = &mut write, if !done_writing => { done_writing = true;}, + } + } +} + +// ---------------------------------------------------------------------------- +// Server + struct ServerConnection { - close: Option>, data: mpsc::Sender, } @@ -64,24 +140,9 @@ impl ServerConnectionTable { } } - fn add(self: &mut Self, id: u64, close: oneshot::Sender<()>, data: mpsc::Sender) { + fn add(self: &mut Self, id: u64, data: mpsc::Sender) { let mut connections = self.connections.lock().unwrap(); - connections.insert( - id, - ServerConnection { - close: Some(close), - data, - }, - ); - } - - fn close(self: &mut Self, id: u64) { - let mut connections = self.connections.lock().unwrap(); - if let Some(connection) = connections.get_mut(&id) { - if let Some(close) = connection.close.take() { - _ = close.send(()); - } - } + connections.insert(id, ServerConnection { data }); } async fn receive(self: &Self, id: u64, buf: Bytes) { @@ -105,16 +166,6 @@ impl ServerConnectionTable { } } -async fn server_connection_write( - data: &mut mpsc::Receiver, - write: &mut T, -) -> Result<(), Error> { - while let Some(buf) = data.recv().await { - write.write_all(&buf[..]).await?; - } - Ok(()) -} - async fn server_handle_connection( channel: u64, port: u16, @@ -123,23 +174,17 @@ async fn server_handle_connection( ) { let mut connections = connections; if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { - let (send_close, closed) = oneshot::channel(); let (send_data, mut data) = mpsc::channel(32); - connections.add(channel, send_close, send_data); + connections.add(channel, send_data); if let Ok(_) = writer.send(Message::Connected(channel)).await { - let (mut read_half, mut write_half) = stream.split(); + let mut writer = writer.clone(); + connection_process(channel, &mut stream, &mut data, &mut writer).await; - // TODO: Read until we get a close on `rx`. - - tokio::select! { - _ = client_connection_read(channel, &mut read_half, writer.clone()) => (), - _ = server_connection_write(&mut data, &mut write_half) => (), - _ = closed => (), - } + eprintln!("< Done server!"); } - connections.remove(channel); } + // Wrong! _ = writer.send(Message::Closed(channel)); } @@ -173,7 +218,11 @@ async fn server_read( Close(channel) => { let mut connections = connections.clone(); tokio::spawn(async move { - connections.close(channel); + // Once we get a close the connection becomes unreachable. + // + // NOTE: If all goes well the 'data' channel gets dropped + // here, and we close the write half of the socket. + connections.remove(channel); }); } Data(channel, buf) => { @@ -262,7 +311,6 @@ async fn client_sync(reader: &mut T) -> Result<(), Error> struct ClientConnection { connected: Option>, - closed: Option>, data: mpsc::Sender, } @@ -286,12 +334,7 @@ impl ClientConnectionTable { } } - fn alloc( - self: &mut Self, - connected: oneshot::Sender<()>, - closed: oneshot::Sender<()>, - data: mpsc::Sender, - ) -> u64 { + fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender) -> u64 { let mut tbl = self.connections.lock().unwrap(); let id = tbl.next_id; tbl.next_id += 1; @@ -299,28 +342,12 @@ impl ClientConnectionTable { id, ClientConnection { connected: Some(connected), - closed: Some(closed), data, }, ); id } - fn closed(self: &mut Self, id: u64) { - let closed = { - let mut tbl = self.connections.lock().unwrap(); - if let Some(c) = tbl.connections.get_mut(&id) { - c.closed.take() - } else { - None - } - }; - - if let Some(closed) = closed { - _ = closed.send(()); - } - } - fn connected(self: &mut Self, id: u64) { let connected = { let mut tbl = self.connections.lock().unwrap(); @@ -357,28 +384,6 @@ impl ClientConnectionTable { } } -async fn client_connection_read( - channel: u64, - read: &mut T, - writer: mpsc::Sender, -) -> Result<(), Error> { - loop { - let mut buffer = BytesMut::with_capacity(64 * 1024); - read.read_buf(&mut buffer).await?; - if buffer.len() == 0 { - return Ok(()); - } - - if let Err(_) = writer.send(Message::Data(channel, buffer.into())).await { - return Err(Error::from(ErrorKind::ConnectionReset)); - } - - // TODO: Flow control here, wait for the packet to be acknowleged so - // there isn't head-of-line blocking or infinite bufferingon the - // remote side. Also buffer re-use! - } -} - async fn client_handle_connection( port: u16, writer: mpsc::Sender, @@ -387,30 +392,19 @@ async fn client_handle_connection( ) { let mut connections = connections; let (send_connected, connected) = oneshot::channel(); - let (send_closed, mut closed) = oneshot::channel(); let (send_data, mut data) = mpsc::channel(32); - let channel_id = connections.alloc(send_connected, send_closed, send_data); + let channel = connections.alloc(send_connected, send_data); - if let Ok(_) = writer.send(Message::Connect(channel_id, port)).await { - let connected = tokio::select! { - _ = connected => true, - _ = &mut closed => false - }; + if let Ok(_) = writer.send(Message::Connect(channel, port)).await { + if let Ok(_) = connected.await { + let mut writer = writer.clone(); + connection_process(channel, socket, &mut data, &mut writer).await; - if connected { - let (mut read_half, mut write_half) = socket.split(); - tokio::select! { - _ = client_connection_read(channel_id, &mut read_half, writer.clone()) => (), - _ = server_connection_write(&mut data, &mut write_half) => (), - _ = closed => () - }; + eprintln!("> Done client!"); } else { eprintln!("> Failed to connect to remote"); } } - - connections.remove(channel_id); - _ = writer.send(Message::Close(channel_id)).await; } async fn client_listen( @@ -421,8 +415,8 @@ async fn client_listen( loop { let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; loop { - // The second item contains the IP and port of the new connection. - // TODO: Handle shutdown correctly. + // The second item contains the IP and port of the new + // connection, but we don't care. let (mut socket, _) = listener.accept().await?; let (writer, connections) = (writer.clone(), connections.clone()); @@ -465,7 +459,7 @@ async fn client_read( Close(channel) => { let mut connections = connections.clone(); tokio::spawn(async move { - connections.closed(channel); + connections.remove(channel); }); } Data(channel, buf) => { @@ -484,7 +478,12 @@ async fn client_read( let port = port.port; if let Some(l) = listeners.remove(&port) { if !l.is_closed() { - // Listen could have failed! + // `l` here is, of course, the channel that we + // use to tell the listener task to stop (see the + // spawn call below). If it isn't closed then + // that means a spawn task is still running so we + // should just let it keep running and re-use the + // existing listener. new_listeners.insert(port, l); } }