Unify connection tables

They're mostly the same but the client side needs a different alloc
and the ability to signal connected.
This commit is contained in:
John Doty 2022-10-08 19:31:46 +00:00
parent 63a02a4211
commit bf4cdcfb6a

View file

@ -78,21 +78,43 @@ struct Connection {
data: mpsc::Sender<Bytes>,
}
#[derive(Clone)]
struct ServerConnectionTable {
connections: Arc<Mutex<HashMap<u64, Connection>>>,
struct ConnectionTableState {
next_id: u64,
connections: HashMap<u64, Connection>,
}
impl ServerConnectionTable {
fn new() -> ServerConnectionTable {
ServerConnectionTable {
connections: Arc::new(Mutex::new(HashMap::new())),
#[derive(Clone)]
struct ConnectionTable {
connections: Arc<Mutex<ConnectionTableState>>,
}
impl ConnectionTable {
fn new() -> ConnectionTable {
ConnectionTable {
connections: Arc::new(Mutex::new(ConnectionTableState {
next_id: 0,
connections: HashMap::new(),
})),
}
}
fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender<Bytes>) -> u64 {
let mut tbl = self.connections.lock().unwrap();
let id = tbl.next_id;
tbl.next_id += 1;
tbl.connections.insert(
id,
Connection {
connected: Some(connected),
data,
},
);
id
}
fn add(self: &mut Self, id: u64, data: mpsc::Sender<Bytes>) {
let mut connections = self.connections.lock().unwrap();
connections.insert(
let mut tbl = self.connections.lock().unwrap();
tbl.connections.insert(
id,
Connection {
connected: None,
@ -101,10 +123,25 @@ impl ServerConnectionTable {
);
}
fn connected(self: &mut Self, id: u64) {
let connected = {
let mut tbl = self.connections.lock().unwrap();
if let Some(c) = tbl.connections.get_mut(&id) {
c.connected.take()
} else {
None
}
};
if let Some(connected) = connected {
_ = connected.send(());
}
}
async fn receive(self: &Self, id: u64, buf: Bytes) {
let data = {
let connections = self.connections.lock().unwrap();
if let Some(connection) = connections.get(&id) {
let tbl = self.connections.lock().unwrap();
if let Some(connection) = tbl.connections.get(&id) {
Some(connection.data.clone())
} else {
None
@ -117,8 +154,8 @@ impl ServerConnectionTable {
}
fn remove(self: &mut Self, id: u64) {
let mut connections = self.connections.lock().unwrap();
connections.remove(&id);
let mut tbl = self.connections.lock().unwrap();
tbl.connections.remove(&id);
}
}
@ -126,7 +163,7 @@ async fn server_handle_connection(
channel: u64,
port: u16,
writer: mpsc::Sender<Message>,
connections: ServerConnectionTable,
connections: ConnectionTable,
) {
let mut connections = connections;
if let Ok(mut stream) = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
@ -144,7 +181,7 @@ async fn server_handle_connection(
async fn server_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>,
connections: ServerConnectionTable,
connections: ConnectionTable,
) -> Result<(), Error> {
eprintln!("< Processing packets...");
loop {
@ -200,7 +237,7 @@ async fn server_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
reader: &mut MessageReader<Reader>,
writer: &mut MessageWriter<Writer>,
) -> Result<(), Error> {
let connections = ServerConnectionTable::new();
let connections = ConnectionTable::new();
// The first message we send must be an announcement.
writer.write(Message::Hello(0, 1, vec![])).await?;
@ -250,80 +287,10 @@ async fn client_sync<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(), Error>
Ok(())
}
struct ClientConnectionTableState {
next_id: u64,
connections: HashMap<u64, Connection>,
}
#[derive(Clone)]
struct ClientConnectionTable {
connections: Arc<Mutex<ClientConnectionTableState>>,
}
impl ClientConnectionTable {
fn new() -> ClientConnectionTable {
ClientConnectionTable {
connections: Arc::new(Mutex::new(ClientConnectionTableState {
next_id: 0,
connections: HashMap::new(),
})),
}
}
fn alloc(self: &mut Self, connected: oneshot::Sender<()>, data: mpsc::Sender<Bytes>) -> u64 {
let mut tbl = self.connections.lock().unwrap();
let id = tbl.next_id;
tbl.next_id += 1;
tbl.connections.insert(
id,
Connection {
connected: Some(connected),
data,
},
);
id
}
fn connected(self: &mut Self, id: u64) {
let connected = {
let mut tbl = self.connections.lock().unwrap();
if let Some(c) = tbl.connections.get_mut(&id) {
c.connected.take()
} else {
None
}
};
if let Some(connected) = connected {
_ = connected.send(());
}
}
async fn receive(self: &Self, id: u64, buf: Bytes) {
let data = {
let tbl = self.connections.lock().unwrap();
if let Some(connection) = tbl.connections.get(&id) {
Some(connection.data.clone())
} else {
None
}
};
if let Some(data) = data {
_ = data.send(buf).await;
}
}
fn remove(self: &mut Self, id: u64) {
let mut tbl = self.connections.lock().unwrap();
tbl.connections.remove(&id);
}
}
async fn client_handle_connection(
port: u16,
writer: mpsc::Sender<Message>,
connections: ClientConnectionTable,
connections: ConnectionTable,
socket: &mut TcpStream,
) {
let mut connections = connections;
@ -346,7 +313,7 @@ async fn client_handle_connection(
async fn client_listen(
port: u16,
writer: mpsc::Sender<Message>,
connections: ClientConnectionTable,
connections: ConnectionTable,
) -> Result<(), Error> {
loop {
let listener = match TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await {
@ -372,7 +339,7 @@ async fn client_listen(
async fn client_read<T: AsyncRead + Unpin>(
reader: &mut MessageReader<T>,
writer: mpsc::Sender<Message>,
connections: ClientConnectionTable,
connections: ConnectionTable,
) -> Result<(), Error> {
let mut listeners: HashMap<u16, oneshot::Sender<()>> = HashMap::new();
@ -462,7 +429,7 @@ async fn client_main<Reader: AsyncRead + Unpin, Writer: AsyncWrite + Unpin>(
eprintln!("> Sending initial list command...");
writer.write(Message::Refresh).await?;
let connections = ClientConnectionTable::new();
let connections = ConnectionTable::new();
// And now really get into it...
let (msg_sender, mut msg_receiver) = mpsc::channel(32);