Unify connection tables

They're mostly the same but the client side needs a different alloc
and the ability to signal connected.
This commit is contained in:
John Doty 2022-10-08 19:31:46 +00:00
parent 63a02a4211
commit bf4cdcfb6a

View file

@ -78,21 +78,43 @@ struct Connection {
data: mpsc::Sender<Bytes>, data: mpsc::Sender<Bytes>,
} }
#[derive(Clone)] struct ConnectionTableState {
struct ServerConnectionTable { next_id: u64,
connections: Arc<Mutex<HashMap<u64, Connection>>>, connections: HashMap<u64, Connection>,
} }
impl ServerConnectionTable { #[derive(Clone)]
fn new() -> ServerConnectionTable { struct ConnectionTable {
ServerConnectionTable { connections: Arc<Mutex<ConnectionTableState>>,
connections: Arc::new(Mutex::new(HashMap::new())),
} }
impl ConnectionTable {
fn new() -> ConnectionTable {
ConnectionTable {
connections: Arc::new(Mutex::new(ConnectionTableState {
next_id: 0,
connections: HashMap::new(),
})),
}
}
fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender<Bytes>) -> u64 {
let mut tbl = self.connections.lock().unwrap();
let id = tbl.next_id;
tbl.next_id += 1;
tbl.connections.insert(
id,
Connection {
connected: Some(connected),
data,
},
);
id
} }
fn add(self: &mut Self, id: u64, data: mpsc::Sender<Bytes>) { fn add(self: &mut Self, id: u64, data: mpsc::Sender<Bytes>) {
let mut connections = self.connections.lock().unwrap(); let mut tbl = self.connections.lock().unwrap();
connections.insert( tbl.connections.insert(
id, id,
Connection { Connection {
connected: None, connected: None,
@ -101,10 +123,25 @@ impl ServerConnectionTable {
); );
} }
fn connected(self: &mut Self, id: u64) {
let connected = {
let mut tbl = self.connections.lock().unwrap();
if let Some(c) = tbl.connections.get_mut(&id) {
c.connected.take()
} else {
None
}
};
if let Some(connected) = connected {
_ = connected.send(());
}
}
async fn receive(self: &Self, id: u64, buf: Bytes) { async fn receive(self: &Self, id: u64, buf: Bytes) {
let data = { let data = {
let connections = self.connections.lock().unwrap(); let tbl = self.connections.lock().unwrap();
if let Some(connection) = connections.get(&id) { if let Some(connection) = tbl.connections.get(&id) {
Some(connection.data.clone()) Some(connection.data.clone())
} else { } else {
None None
@ -117,8 +154,8 @@ impl ServerConnectionTable {
} }
fn remove(self: &mut Self, id: u64) { fn remove(self: &mut Self, id: u64) {
let mut connections = self.connections.lock().unwrap(); let mut tbl = self.connections.lock().unwrap();
connections.remove(&id); tbl.connections.remove(&id);
} }
} }
@ -126,7 +163,7 @@ async fn server_handle_connection(
channel: u64, channel: u64,
port: u16, port: u16,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ServerConnectionTable, connections: ConnectionTable,
) { ) {
let mut connections = connections; let mut connections = connections;
if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
@ -144,7 +181,7 @@ async fn server_handle_connection(
async fn server_read<T: AsyncRead + Unpin>( async fn server_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>, reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ServerConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<(), Error> {
eprintln!("< Processing packets..."); eprintln!("< Processing packets...");
loop { loop {
@ -200,7 +237,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>, reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>, writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let connections = ServerConnectionTable::new(); let connections = ConnectionTable::new();
// The first message we send must be an announcement. // The first message we send must be an announcement.
writer.write(Message::Hello(0, 1, vec![])).await?; writer.write(Message::Hello(0, 1, vec![])).await?;
@ -250,80 +287,10 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
Ok(()) Ok(())
} }
struct ClientConnectionTableState {
next_id: u64,
connections: HashMap<u64, Connection>,
}
#[derive(Clone)]
struct ClientConnectionTable {
connections: Arc<Mutex<ClientConnectionTableState>>,
}
impl ClientConnectionTable {
fn new() -> ClientConnectionTable {
ClientConnectionTable {
connections: Arc::new(Mutex::new(ClientConnectionTableState {
next_id: 0,
connections: HashMap::new(),
})),
}
}
fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender<Bytes>) -> u64 {
let mut tbl = self.connections.lock().unwrap();
let id = tbl.next_id;
tbl.next_id += 1;
tbl.connections.insert(
id,
Connection {
connected: Some(connected),
data,
},
);
id
}
fn connected(self: &mut Self, id: u64) {
let connected = {
let mut tbl = self.connections.lock().unwrap();
if let Some(c) = tbl.connections.get_mut(&id) {
c.connected.take()
} else {
None
}
};
if let Some(connected) = connected {
_ = connected.send(());
}
}
async fn receive(self: &Self, id: u64, buf: Bytes) {
let data = {
let tbl = self.connections.lock().unwrap();
if let Some(connection) = tbl.connections.get(&id) {
Some(connection.data.clone())
} else {
None
}
};
if let Some(data) = data {
_ = data.send(buf).await;
}
}
fn remove(self: &mut Self, id: u64) {
let mut tbl = self.connections.lock().unwrap();
tbl.connections.remove(&id);
}
}
async fn client_handle_connection( async fn client_handle_connection(
port: u16, port: u16,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ClientConnectionTable, connections: ConnectionTable,
socket: &mut TcpStream, socket: &mut TcpStream,
) { ) {
let mut connections = connections; let mut connections = connections;
@ -346,7 +313,7 @@ async fn client_handle_connection(
async fn client_listen( async fn client_listen(
port: u16, port: u16,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ClientConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<(), Error> {
loop { loop {
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await { let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
@ -372,7 +339,7 @@ async fn client_listen(
async fn client_read<T: AsyncRead + Unpin>( async fn client_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>, reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>, writer: mpsc::Sender<Message>,
connections: ClientConnectionTable, connections: ConnectionTable,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new(); let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
@ -462,7 +429,7 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
eprintln!("> Sending initial list command..."); eprintln!("> Sending initial list command...");
writer.write(Message::Refresh).await?; writer.write(Message::Refresh).await?;
let connections = ClientConnectionTable::new(); let connections = ConnectionTable::new();
// And now really get into it... // And now really get into it...
let (msg_sender, mut msg_receiver) = mpsc::channel(32); let (msg_sender, mut msg_receiver) = mpsc::channel(32);