diff --git a/src/lib.rs b/src/lib.rs index a72004d..bbd1a96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,21 +78,43 @@ struct Connection { data: mpsc::Sender, } -#[derive(Clone)] -struct ServerConnectionTable { - connections: Arc>>, +struct ConnectionTableState { + next_id: u64, + connections: HashMap, } -impl ServerConnectionTable { - fn new() -> ServerConnectionTable { - ServerConnectionTable { - connections: Arc::new(Mutex::new(HashMap::new())), +#[derive(Clone)] +struct ConnectionTable { + connections: Arc>, +} + +impl ConnectionTable { + fn new() -> ConnectionTable { + ConnectionTable { + connections: Arc::new(Mutex::new(ConnectionTableState { + next_id: 0, + connections: HashMap::new(), + })), } } + fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender) -> u64 { + let mut tbl = self.connections.lock().unwrap(); + let id = tbl.next_id; + tbl.next_id += 1; + tbl.connections.insert( + id, + Connection { + connected: Some(connected), + data, + }, + ); + id + } + fn add(self: &mut Self, id: u64, data: mpsc::Sender) { - let mut connections = self.connections.lock().unwrap(); - connections.insert( + let mut tbl = self.connections.lock().unwrap(); + tbl.connections.insert( id, Connection { connected: None, @@ -101,10 +123,25 @@ impl ServerConnectionTable { ); } + fn connected(self: &mut Self, id: u64) { + let connected = { + let mut tbl = self.connections.lock().unwrap(); + if let Some(c) = tbl.connections.get_mut(&id) { + c.connected.take() + } else { + None + } + }; + + if let Some(connected) = connected { + _ = connected.send(()); + } + } + async fn receive(self: &Self, id: u64, buf: Bytes) { let data = { - let connections = self.connections.lock().unwrap(); - if let Some(connection) = connections.get(&id) { + let tbl = self.connections.lock().unwrap(); + if let Some(connection) = tbl.connections.get(&id) { Some(connection.data.clone()) } else { None @@ -117,8 +154,8 @@ impl ServerConnectionTable { } fn remove(self: &mut Self, id: u64) { - let mut connections = self.connections.lock().unwrap(); - connections.remove(&id); + let mut tbl = self.connections.lock().unwrap(); + tbl.connections.remove(&id); } } @@ -126,7 +163,7 @@ async fn server_handle_connection( channel: u64, port: u16, writer: mpsc::Sender, - connections: ServerConnectionTable, + connections: ConnectionTable, ) { let mut connections = connections; if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { @@ -144,7 +181,7 @@ async fn server_handle_connection( async fn server_read( reader: &mut MessageReader, writer: mpsc::Sender, - connections: ServerConnectionTable, + connections: ConnectionTable, ) -> Result<(), Error> { eprintln!("< Processing packets..."); loop { @@ -200,7 +237,7 @@ async fn server_main( reader: &mut MessageReader, writer: &mut MessageWriter, ) -> Result<(), Error> { - let connections = ServerConnectionTable::new(); + let connections = ConnectionTable::new(); // The first message we send must be an announcement. writer.write(Message::Hello(0, 1, vec![])).await?; @@ -250,80 +287,10 @@ async fn client_sync(reader: &mut T) -> Result<(), Error> Ok(()) } -struct ClientConnectionTableState { - next_id: u64, - connections: HashMap, -} - -#[derive(Clone)] -struct ClientConnectionTable { - connections: Arc>, -} - -impl ClientConnectionTable { - fn new() -> ClientConnectionTable { - ClientConnectionTable { - connections: Arc::new(Mutex::new(ClientConnectionTableState { - next_id: 0, - connections: HashMap::new(), - })), - } - } - - fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender) -> u64 { - let mut tbl = self.connections.lock().unwrap(); - let id = tbl.next_id; - tbl.next_id += 1; - tbl.connections.insert( - id, - Connection { - connected: Some(connected), - data, - }, - ); - id - } - - fn connected(self: &mut Self, id: u64) { - let connected = { - let mut tbl = self.connections.lock().unwrap(); - if let Some(c) = tbl.connections.get_mut(&id) { - c.connected.take() - } else { - None - } - }; - - if let Some(connected) = connected { - _ = connected.send(()); - } - } - - async fn receive(self: &Self, id: u64, buf: Bytes) { - let data = { - let tbl = self.connections.lock().unwrap(); - if let Some(connection) = tbl.connections.get(&id) { - Some(connection.data.clone()) - } else { - None - } - }; - - if let Some(data) = data { - _ = data.send(buf).await; - } - } - - fn remove(self: &mut Self, id: u64) { - let mut tbl = self.connections.lock().unwrap(); - tbl.connections.remove(&id); - } -} - async fn client_handle_connection( port: u16, writer: mpsc::Sender, - connections: ClientConnectionTable, + connections: ConnectionTable, socket: &mut TcpStream, ) { let mut connections = connections; @@ -346,7 +313,7 @@ async fn client_handle_connection( async fn client_listen( port: u16, writer: mpsc::Sender, - connections: ClientConnectionTable, + connections: ConnectionTable, ) -> Result<(), Error> { loop { let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { @@ -372,7 +339,7 @@ async fn client_listen( async fn client_read( reader: &mut MessageReader, writer: mpsc::Sender, - connections: ClientConnectionTable, + connections: ConnectionTable, ) -> Result<(), Error> { let mut listeners: HashMap> = HashMap::new(); @@ -462,7 +429,7 @@ async fn client_main( eprintln!("> Sending initial list command..."); writer.write(Message::Refresh).await?; - let connections = ClientConnectionTable::new(); + let connections = ConnectionTable::new(); // And now really get into it... let (msg_sender, mut msg_receiver) = mpsc::channel(32);