diff --git a/src/connection.rs b/src/connection.rs index 2288d2b..9bf697d 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -36,6 +36,7 @@ async fn connection_read( if let Err(e) = read.read_buf(&mut buffer).await { break Err(Error::IO(e)); } + if buffer.len() == 0 { break Ok(()); } @@ -52,7 +53,7 @@ async fn connection_read( // We are effectively closed on this side, send the close to drop the // corresponding write side on the other end of the pipe. _ = writer.send(Message::Close(channel)).await; - return result; + result } /// Get messages from a queue and write them out to a socket until there are @@ -99,11 +100,14 @@ pub async fn process( while !(done_reading && done_writing) { tokio::select! { _ = &mut read, if !done_reading => { done_reading = true; }, - _ = &mut write, if !done_writing => { done_writing = true;}, + _ = &mut write, if !done_writing => { done_writing = true; }, } } } +// ---------------------------------------------------------------------------- +// Tables + /// The connection structure tracks the various channels used to communicate /// with an "open" connection. struct Connection { @@ -220,3 +224,100 @@ impl ConnectionTable { tbl.connections.remove(&id); } } + +#[cfg(test)] +mod tests { + use super::*; + use tokio::net::TcpListener; + + async fn create_connected_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + let connect = tokio::spawn(async move { + TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap() + }); + + let (server, _) = listener.accept().await.unwrap(); + let client = connect.await.unwrap(); + + (client, server) + } + + #[tokio::test] + async fn test_connected_pair() { + // This is just a sanity test to make sure my socket nonsense is working. + let (mut client, mut server) = create_connected_pair().await; + + let a = tokio::spawn(async move { + let mut d = vec![1, 2, 3]; + client.write_all(&mut d).await.unwrap(); + //eprintln!("Wrote something!"); + }); + + let b = tokio::spawn(async move { + let mut x = BytesMut::with_capacity(3); + server.read_buf(&mut x).await.unwrap(); + //panic!("Read {:?}", x); + }); + + a.await.unwrap(); + b.await.unwrap(); + } + + #[tokio::test] + async fn test_process_connection() { + let (mut client, mut server) = create_connected_pair().await; + + const CHID: u64 = 123; + let (mut msg_writer, mut msg_receiver) = mpsc::channel(32); + let (data_writer, mut data_receiver) = mpsc::channel(32); + + let proc = tokio::spawn(async move { + process(CHID, &mut server, &mut data_receiver, &mut msg_writer).await + }); + + // Any bytes I send through `data_writer` will come into my socket. + let packet = Bytes::from("hello world"); + data_writer.send(packet.clone()).await.unwrap(); + + let mut buffer = BytesMut::with_capacity(packet.len()); + buffer.resize(packet.len(), 0); + client.read_exact(&mut buffer).await.unwrap(); + assert_eq!(packet, buffer); + + // Any bytes I send through client come through on msg_receiver. + client.write_all(&packet[..]).await.unwrap(); + let msg = msg_receiver.recv().await.unwrap(); + assert_eq!(msg, Message::Data(CHID, packet.clone())); + + // When I close the write half of the socket then I get a close + // message. + let (mut read_half, mut write_half) = client.split(); + write_half.shutdown().await.unwrap(); + let msg = msg_receiver.recv().await.unwrap(); + assert_eq!(msg, Message::Close(CHID)); + + // I should still be able to use the read half of the socket. + let packet = Bytes::from("StIlL AlIvE"); + data_writer.send(packet.clone()).await.unwrap(); + + let mut buffer = BytesMut::with_capacity(packet.len()); + buffer.resize(packet.len(), 0); + read_half.read_exact(&mut buffer).await.unwrap(); + assert_eq!(packet, buffer); + + // When I drop the data writer my read half closes. + drop(data_writer); + let mut buffer = BytesMut::with_capacity(1024); + read_half.read_buf(&mut buffer).await.unwrap(); + assert_eq!(buffer.len(), 0); + + drop(read_half); + + // and the processing loop terminates. + proc.await.unwrap(); + } +}