diff --git a/src/client/ui.rs b/src/client/ui.rs index 4ba2b04..15ea41c 100644 --- a/src/client/ui.rs +++ b/src/client/ui.rs @@ -127,6 +127,11 @@ impl Listener { } } + #[cfg(test)] + pub fn state_ref(&self) -> std::sync::Arc> { + self.state.clone() + } + pub fn enabled(&self) -> bool { self.state() == State::Enabled } @@ -154,8 +159,8 @@ impl Listener { *self.state.lock().unwrap() } - pub fn set_enabled(&mut self, socks_port: Option, enabled: bool) { - if enabled { + pub fn toggle_enabled(&mut self, socks_port: Option) { + if self.state() == State::Disabled { self.state = State::Enabled.boxed(); self.start(socks_port); } else { @@ -435,7 +440,7 @@ impl UI { fn enable_disable_port(&mut self, port: u16) { if let Some(listener) = self.ports.get_mut(&port) { - listener.set_enabled(self.socks_port, !listener.enabled()); + listener.toggle_enabled(self.socks_port); } } @@ -1204,4 +1209,46 @@ mod tests { drop(sender); } + + #[test] + fn state_toggle_enable_disable() { + let (sender, receiver) = mpsc::channel(64); + let config = ServerConfig::default(); + let mut ui = UI::new(receiver, config); + + ui.handle_internal_event(Some(UIEvent::Ports(vec![PortDesc { + port: 8080, + desc: "rando".to_string(), + }]))); + + let listener = ui.ports.get_mut(&8080).unwrap(); + assert_eq!(listener.state(), State::Enabled); + + // Enabled -> Disabled + ui.enable_disable_port(8080); // FLIP! + let listener = ui.ports.get(&8080).unwrap(); + assert_eq!(listener.state(), State::Disabled); + + // Disabled -> Enabled + ui.enable_disable_port(8080); // FLIP! + let listener = ui.ports.get(&8080).unwrap(); + assert_eq!(listener.state(), State::Enabled); + + { + // Oh no it broke! + let state = listener.state_ref(); + let mut sg = state.lock().unwrap(); + *sg = State::Broken; + } + + let listener = ui.ports.get_mut(&8080).unwrap(); + assert_eq!(listener.state(), State::Broken); + + // Broken -> Disabled + ui.enable_disable_port(8080); + let listener = ui.ports.get_mut(&8080).unwrap(); + assert_eq!(listener.state(), State::Disabled); + + drop(sender); + } }