diff --git a/Cargo.lock b/Cargo.lock index 9ced9a7..3b37c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,12 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +[[package]] +name = "bytes" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" + [[package]] name = "cc" version = "1.0.73" @@ -118,7 +124,18 @@ dependencies = [ name = "fwd" version = "0.1.0" dependencies = [ + "bytes", "procfs", + "tokio", +] + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", ] [[package]] @@ -173,6 +190,16 @@ version = "0.0.46" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4d2456c373231a208ad294c33dc5bff30051eafd954cd4caae83a712b12854d" +[[package]] +name = "lock_api" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.17" @@ -182,6 +209,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + [[package]] name = "miniz_oxide" version = "0.5.4" @@ -191,6 +224,18 @@ dependencies = [ "adler", ] +[[package]] +name = "mio" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -210,12 +255,51 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + [[package]] name = "proc-macro2" version = "1.0.46" @@ -249,6 +333,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + [[package]] name = "rustix" version = "0.35.11" @@ -263,6 +356,37 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "socket2" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "syn" version = "1.0.101" @@ -281,10 +405,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" dependencies = [ "libc", - "wasi", + "wasi 0.10.0+wasi-snapshot-preview1", "winapi", ] +[[package]] +name = "tokio" +version = "1.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.4" @@ -297,6 +452,12 @@ version = "0.10.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.83" diff --git a/Cargo.toml b/Cargo.toml index 06fd5b5..58e0876 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,4 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -procfs = "0.14.1" \ No newline at end of file +bytes = "1" +procfs = "0.14.1" +tokio = { version = "1", features = ["full"] } diff --git a/src/garbage.rs b/src/garbage.rs new file mode 100644 index 0000000..6e7881a --- /dev/null +++ b/src/garbage.rs @@ -0,0 +1,630 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use procfs::process::FDTarget; +use std::collections::HashMap; +use std::io::Cursor; +use std::net::{Ipv4Addr, SocketAddrV4}; +use std::sync::{Arc, Mutex}; +use tokio::io::{ + AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf, +}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; +use tokio::sync::oneshot; + +// ============================================================= +// Looking for listening ports +// ============================================================= + +struct PortDesc { + port: u16, + desc: String, +} + +fn get_entries() -> procfs::ProcResult> { + let all_procs = procfs::process::all_processes()?; + + // build up a map between socket inodes and process stat info. Ignore any + // error we encounter as it probably means we have no access to that + // process or something. + let mut map: HashMap = HashMap::new(); + for p in all_procs { + if let Ok(process) = p { + if !process.is_alive() { + continue; // Ignore zombies. + } + + if let (Ok(fds), Ok(cmd)) = (process.fd(), process.cmdline()) { + for fd in fds { + if let Ok(fd) = fd { + if let FDTarget::Socket(inode) = fd.target { + map.insert(inode, cmd.join(" ")); + } + } + } + } + } + } + + let mut h: HashMap = HashMap::new(); + + // Go through all the listening IPv4 and IPv6 sockets and take the first + // instance of listening on each port *if* the address is loopback or + // unspecified. (TODO: Do we want this restriction really?) + let tcp = procfs::net::tcp()?; + let tcp6 = procfs::net::tcp6()?; + for tcp_entry in tcp.into_iter().chain(tcp6) { + if tcp_entry.state == procfs::net::TcpState::Listen + && (tcp_entry.local_address.ip().is_loopback() + || tcp_entry.local_address.ip().is_unspecified()) + && !h.contains_key(&tcp_entry.local_address.port()) + { + if let Some(cmd) = map.get(&tcp_entry.inode) { + h.insert( + tcp_entry.local_address.port(), + PortDesc { + port: tcp_entry.local_address.port(), + desc: cmd.clone(), + }, + ); + } + } + } + + Ok(h.into_values().collect()) +} + +// ============================================================= +// Sending and receiving data +// ============================================================= + +// A channel that can receive packets from the remote side. +struct Channel { + packets: mpsc::Sender, // TODO: spsc probably +} + +struct Channels { + channels: Mutex>>, + removed: mpsc::Sender, // TODO: send an error? +} + +impl Channels { + fn new(removed: mpsc::Sender) -> Channels { + Channels { + channels: Mutex::new(HashMap::new()), + removed: removed, + } + } + + fn add(self: &Self, id: u16, channel: Channel) { + let mut channels = self.channels.lock().unwrap(); + channels.insert(id, Arc::new(channel)); + } + + fn get(self: &Self, id: u16) -> Option> { + let channels = self.channels.lock().unwrap(); + if let Some(channel) = channels.get(&id) { + Some(channel.clone()) + } else { + None + } + } + + async fn remove(self: &Self, id: u16) { + { + let mut channels = self.channels.lock().unwrap(); + channels.remove(&id); + } + _ = self.removed.send(id).await; + } +} + +enum Error { + Incomplete, + UnknownMessage, + Corrupt, +} + +enum Message { + Ping, + Connect(u64, u16), // Request to connect on a port from client to server. + Connected(u64, u16), // Sucessfully connected from server to client. + Close(u64), // Request to close from client to server. + Abort(u64), // Notify of close from server to client. + Closed(u64), // Response to Close or Abort. + Refresh, // Request to refresh list of ports from client. + Ports(Vec), // List of available ports from server to client. +} + +impl Message { + fn encode(self: &Message, dest: T) -> BytesMut { + use Message::*; + let result = BytesMut::new(); + match self { + Ping => { + result.put_u8(0x00); + } + Connect(channel, port) => { + result.put_u8(0x01); + result.put_u64(*channel); + result.put_u16(*port); + } + Connected(channel, port) => { + result.put_u8(0x02); + result.put_u64(*channel); + result.put_u16(*port); + } + Close(channel) => { + result.put_u8(0x03); + result.put_u64(*channel); + } + Abort(channel) => { + result.put_u8(0x04); + result.put_u64(*channel); + } + Closed(channel) => { + result.put_u8(0x05); + result.put_u64(*channel); + } + Refresh => { + result.put_u8(0x06); + } + Ports(ports) => { + result.put_u8(0x07); + + result.put_u16(u16::try_from(ports.len()).expect("Too many ports")); + for port in ports { + result.put_u16(port.port); + + let sliced = slice_up_to(&port.desc, u16::max_value().into()); + result.put_u16(u16::try_from(sliced.len()).unwrap()); + result.put_slice(sliced.as_bytes()); + } + } + }; + result + } + + fn decode(cursor: &mut Cursor<&[u8]>) -> Result { + use Message::*; + match get_u8(cursor)? { + 0x00 => Ok(Ping), + 0x01 => { + let channel = get_u64(cursor)?; + let port = get_u16(cursor)?; + Ok(Connect(channel, port)) + } + 0x02 => { + let channel = get_u64(cursor)?; + let port = get_u16(cursor)?; + Ok(Connected(channel, port)) + } + 0x03 => { + let channel = get_u64(cursor)?; + Ok(Close(channel)) + } + 0x04 => { + let channel = get_u64(cursor)?; + Ok(Abort(channel)) + } + 0x05 => { + let channel = get_u64(cursor)?; + Ok(Closed(channel)) + } + 0x06 => Ok(Refresh), + 0x07 => { + let count = get_u16(cursor)?; + + let mut ports = Vec::new(); + for i in 0..count { + let port = get_u16(cursor)?; + let length = get_u16(cursor)?; + + let data = get_bytes(cursor, length.into())?; + let desc = match std::str::from_utf8(&data[..]) { + Ok(s) => s.to_owned(), + Err(_) => return Err(Error::Corrupt), + }; + + ports.push(PortDesc { port, desc }); + } + Ok(Ports(ports)) + } + _ => Err(Error::Corrupt), + } + } +} + +fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { + if !cursor.has_remaining() { + return Err(Error::Incomplete); + } + Ok(cursor.get_u8()) +} + +fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { + if cursor.remaining() < 2 { + return Err(Error::Incomplete); + } + Ok(cursor.get_u16()) +} + +fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { + if cursor.remaining() < 8 { + return Err(Error::Incomplete); + } + Ok(cursor.get_u64()) +} + +fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { + if cursor.remaining() < length { + return Err(Error::Incomplete); + } + + Ok(cursor.copy_to_bytes(length)) +} + +pub fn slice_up_to(s: &str, max_len: usize) -> &str { + if max_len >= s.len() { + return s; + } + let mut idx = max_len; + while !s.is_char_boundary(idx) { + idx -= 1; + } + &s[..idx] +} + +struct ControllerState { + on_removed: mpsc::Receiver, + on_packet: mpsc::Receiver, +} + +struct ClientController { + channels: Arc, + state: Option, +} + +impl ClientController { + fn new() -> ClientController { + let (packets, on_packet) = mpsc::channel(32); + let (removed, on_removed) = mpsc::channel(32); + + let channels = Arc::new(Channels::new(removed)); + channels.add(0, Channel { packets }); + + ClientController { + channels, + state: Some(ControllerState { + on_removed, + on_packet, + }), + } + } + + fn channels(self: &Self) -> Arc { + self.channels.clone() + } + + fn start(self: &mut Self, stop: oneshot::Receiver<()>) { + if let Some(state) = self.state.take() { + tokio::spawn(async move { + let mut state = state; + tokio::select! { + _ = Self::process_channel_remove(&mut state.on_removed) => (), + _ = Self::process_packets(&mut state.on_packet) => (), + _ = stop => (), + } + }); + } + } + + async fn process_packets(packets: &mut mpsc::Receiver) { + while let Some(_) = packets.recv().await {} + } + + async fn process_channel_remove(removals: &mut mpsc::Receiver) { + while let Some(_) = removals.recv().await {} + } +} + +// TODO: Need flow control on the send side too because we don't want to +// block everybody if there's a slow reader on the other side. So the +// completion one-shot we send to the mux needs to go in a table until +// we get an ack back across the TCP channel. + +// A packet being sent across the channel. +struct OutgoingPacket { + channel: u16, // Channel 0 reserved as control channel. + data: BytesMut, + // Where we notify folks when the data has been sent. + sent: oneshot::Sender, +} + +// This is the "read" side of a forwarded connection; it reads from the read +// half of a TCP stream and sends those reads to into the multiplexing +// connection to be sent to the other side. +// +// (There's another function that handles the write side of a connection.) +async fn handle_read_side( + channel: u16, + read: &mut ReadHalf, + mux: mpsc::Sender, +) -> Result<(), tokio::io::Error> { + let mut buffer = BytesMut::with_capacity(u16::max_value().into()); + loop { + read.read_buf(&mut buffer).await?; + + let (tx, rx) = oneshot::channel::(); + let op = OutgoingPacket { + channel, + data: buffer, + sent: tx, + }; + if let Err(_) = mux.send(op).await { + return Ok(()); + } + + match rx.await { + Ok(b) => buffer = b, + Err(_) => return Ok(()), + } + + assert_eq!(buffer.capacity(), u16::max_value().into()); + buffer.clear(); + } +} + +struct IncomingPacket { + data: BytesMut, + sent: oneshot::Sender, +} + +// This is the "write" side of a forwarded connection; it receives data from +// the multiplexed connection and sends it out over the write half of a TCP +// stream to the attached process. +// +// (There's another function that handles the "read" side of a connection, +// which is a little more complex.) +async fn handle_write_side( + channel: u16, + packets: &mut mpsc::Receiver, // TODO: spsc I think + write: &mut WriteHalf, +) -> Result<(), tokio::io::Error> { + while let Some(IncomingPacket { data, sent }) = packets.recv().await { + // Write the data out to the write end of the TCP stream. + write.write_all(&data[..]).await?; + + // Now we've sent it, we can send the buffer back and let the caller + // know we wrote it. Literally don't care if they're still listening + // or not; should I care? + _ = sent.send(data); + } + + Ok(()) +} + +async fn handle_connection( + packets: &mut mpsc::Receiver, // TODO: spsc I think + mux: mpsc::Sender, + channel: u16, + socket: TcpStream, +) -> Result<(), tokio::io::Error> { + // Handle the read and write side of the socket separately. + let (mut read, mut write) = tokio::io::split(socket); + + let writer = handle_write_side(channel, packets, &mut write); + let reader = handle_read_side(channel, &mut read, mux); + + // Wait for both to be done. If either the reader or the writer completes + // with an error then we're just going to shut down the whole thing, + // closing the socket, &c. But either side can shut down cleanly and + // that's fine! + tokio::pin!(writer); + tokio::pin!(reader); + let (mut read_done, mut write_done) = (false, false); + while !(read_done && write_done) { + tokio::select! { + write_result = &mut writer, if !write_done => { + write_done = true; + if let Err(e) = write_result { + return Err(e); + } + }, + read_result = &mut reader, if !read_done => { + read_done = true; + if let Err(e) = read_result { + return Err(e); + } + }, + } + } + + Ok(()) +} + +async fn allocate_channel() -> ( + u16, + mpsc::Receiver, + mpsc::Sender, +) { + panic!("Not implemented"); +} + +// This is only on the client side of the connection. +async fn handle_listen(port: u16, channels: Arc) -> Result<(), tokio::io::Error> { + loop { + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; + loop { + // The second item contains the IP and port of the new connection. + // TODO: Handle shutdown correctly. + let (socket, _) = listener.accept().await?; + + tokio::spawn(async move { + // Finish the connect and then.... + let (channel, mut packets, mux) = allocate_channel().await; + + // ....handle the connection asynchronously... + let result = handle_connection(&mut packets, mux, channel, socket).await; + + // ...and then shut it down. + // close_channel(channel, result).await; + }); + } + } +} + +// Multiplex writes onto the given writer. Writes come in through the +// specified channel, and are multiplexed to the writer. +async fn mux_packets( + rx: &mut mpsc::Receiver, + writer: &mut T, +) { + while let Some(OutgoingPacket { + channel, + data, + sent, + }) = rx.recv().await + { + // Send the packet over the shared connection. + // TODO: Technically, for flow control purposes, we should mark + // the transmission as pending right now, and wait for an ack + // from the server in order to "complete" the write. OR we + // should do even better and caqp the number of outstanding + // writes per channel. + writer + .write_u16(channel) + .await + .expect("Error writing channel"); + + writer + .write_u16(u16::try_from(data.len()).expect("Multiplexed buffer too big")) + .await + .expect("Error writing length"); + + writer + .write_all(&data[..]) + .await + .expect("Error writing data"); + + sent.send(data).expect("Error notifying of completion"); + } +} + +fn new_muxer(writer: T) -> mpsc::Sender { + let mut writer = writer; + let (tx, mut rx) = mpsc::channel::(32); + tokio::spawn(async move { + mux_packets(&mut rx, &mut writer).await; + }); + tx +} + +async fn demux_packets(reader: &mut T, channels: Arc) { + let mut buffer = BytesMut::with_capacity(u16::max_value().into()); + loop { + let chid = reader + .read_u16() + .await + .expect("Error reading channel number from connection"); + let length = reader + .read_u16() + .await + .expect("Error reading length from connection"); + + let tail = buffer.split_off(length.into()); + reader + .read_exact(&mut buffer) + .await + .expect("Error reading data from connection"); + + if let Some(channel) = channels.get(chid) { + let (sent, is_sent) = oneshot::channel::(); + let packet = IncomingPacket { data: buffer, sent }; + + if let Err(_) = channel.packets.send(packet).await { + // TODO: Log Error + buffer = BytesMut::with_capacity(u16::max_value().into()); + channels.remove(chid).await; + } else { + match is_sent.await { + Ok(b) => { + buffer = b; + buffer.unsplit(tail); + } + Err(_) => { + // TODO: Log Error + buffer = BytesMut::with_capacity(u16::max_value().into()); + channels.remove(chid).await; + } + } + } + } + + buffer.clear(); + } +} + +async fn spawn_ssh(server: String) -> Result { + // let mut cmd = process::Command::new("echo"); + // cmd.stdout(Stdio::piped()); + // cmd.stdin(Stdio::piped()); + panic!("Not Implemented"); +} + +#[tokio::main] +async fn main() { + // Create the client-side controller. + let mut controller = ClientController::new(); + + // Spawn an SSH connection to the remote side. + let mut child = spawn_ssh("coder.doty-dev".into()) + .await + .expect("failed to spawn"); + + // Build a multiplexer around stdin. + let muxer = new_muxer(BufWriter::new( + child + .stdin + .take() + .expect("child did not have a handle to stdin"), + )); + + // Buffer input and output, FOR SPEED! + let mut reader = BufReader::new( + child + .stdout + .take() + .expect("child did not have a handle to stdout"), + ); + + let channels = controller.channels().clone(); + tokio::spawn(async move { + demux_packets(&mut reader, channels).await; + }); + + // let mut writer = + // Start up a task that's watching the SSH connection for completion. + // Presumably stdin and stdout will be closed and I'll get read/write + // errors and whatnot. + tokio::spawn(async move { + let status = child + .wait() + .await + .expect("child process encountered an error"); + + println!("child status was: {}", status); + }); + + // TODO: Wait for stdout to indicate readiness, or for a timeout to indicate it + // hasn't started. Note that some ssh implementations spit stuff + // into stdout that we ought to ignore, or there's a login MOTD or + // something, and we should just ignore it until we see the magic + // bytes. + + let (send_stop, stop) = oneshot::channel(); + controller.start(stop); + + // I guess we stop on a control-C? + + _ = send_stop.send(()); +} diff --git a/src/main.rs b/src/main.rs index fb654ca..e2b0e0e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,66 +1,314 @@ -use procfs::process::FDTarget; +use bytes::BytesMut; use std::collections::HashMap; +use std::io::Cursor; +use std::net::{Ipv4Addr, SocketAddrV4}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; +use tokio::net::TcpListener; +use tokio::process; +use tokio::sync::mpsc; +use tokio::sync::oneshot; -struct Entry { - port: u16, - desc: String, +mod message; +mod refresh; + +use message::Message; + +struct MessageWriter { + writer: T, } -fn get_entries() -> procfs::ProcResult> { - let all_procs = procfs::process::all_processes()?; +impl MessageWriter { + fn new(writer: T) -> MessageWriter { + MessageWriter { writer } + } + async fn write(self: &mut Self, msg: Message) -> Result<(), tokio::io::Error> { + // TODO: Optimize buffer usage please this is bad + let mut buffer = msg.encode(); + self.writer + .write_u32(buffer.len().try_into().expect("Message too large")) + .await?; + self.writer.write_buf(&mut buffer).await?; + self.writer.flush().await?; + Ok(()) + } +} - // build up a map between socket inodes and process stat info. Ignore any - // error we encounter as it probably means we have no access to that - // process or something. - let mut map: HashMap = HashMap::new(); - for p in all_procs { - if let Ok(process) = p { - if !process.is_alive() { - continue; // Ignore zombies. +async fn pump_write( + messages: &mut mpsc::Receiver, + writer: &mut MessageWriter, +) -> Result<(), tokio::io::Error> { + while let Some(msg) = messages.recv().await { + writer.write(msg).await?; + } + Ok(()) +} + +async fn server_read( + reader: &mut T, + writer: mpsc::Sender, +) -> Result<(), tokio::io::Error> { + eprintln!("< Processing packets..."); + loop { + let frame_length = reader.read_u32().await?; + + let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); + reader.read_buf(&mut data).await?; + + let mut cursor = Cursor::new(&data[..]); + let message = match Message::decode(&mut cursor) { + Ok(msg) => msg, + Err(_) => return Err(tokio::io::Error::from(tokio::io::ErrorKind::InvalidData)), + }; + + use Message::*; + match message { + Ping => (), + Refresh => { + let writer = writer.clone(); + tokio::spawn(async move { + let ports = match refresh::get_entries() { + Ok(ports) => ports, + Err(e) => { + eprintln!("< Error scanning: {:?}", e); + vec![] + } + }; + if let Err(e) = writer.send(Message::Ports(ports)).await { + // Writer has been closed for some reason, we can just quit.... I hope everything is OK? + eprintln!("< Warning: Error sending: {:?}", e); + } + }); } + _ => panic!("Unsupported: {:?}", message), + }; + } +} - if let (Ok(fds), Ok(cmd)) = (process.fd(), process.cmdline()) { - for fd in fds { - if let Ok(fd) = fd { - if let FDTarget::Socket(inode) = fd.target { - map.insert(inode, cmd.join(" ")); +async fn server_main( + reader: &mut Reader, + writer: &mut MessageWriter, +) -> Result<(), tokio::io::Error> { + // Jump into it... + let (msg_sender, mut msg_receiver) = mpsc::channel(32); + let writing = pump_write(&mut msg_receiver, writer); + let reading = server_read(reader, msg_sender); + tokio::pin!(reading); + tokio::pin!(writing); + + let (mut done_writing, mut done_reading) = (false, false); + loop { + tokio::select! { + result = &mut writing, if !done_writing => { + done_writing = true; + if let Err(e) = result { + return Err(e); + } + if done_reading && done_writing { + return Ok(()); + } + }, + result = &mut reading, if !done_reading => { + done_reading = true; + if let Err(e) = result { + return Err(e); + } + if done_reading && done_writing { + return Ok(()); + } + }, + } + } +} + +async fn spawn_ssh(server: &str) -> Result { + let mut cmd = process::Command::new("ssh"); + cmd.arg("-T").arg(server).arg("fwd").arg("--server"); + + cmd.stdout(std::process::Stdio::piped()); + cmd.stdin(std::process::Stdio::piped()); + cmd.spawn() +} + +async fn client_sync(reader: &mut T) -> Result<(), tokio::io::Error> { + eprintln!("> Waiting for synchronization marker..."); + let mut seen = 0; + while seen < 8 { + let byte = reader.read_u8().await?; + seen = if byte == 0 { seen + 1 } else { 0 }; + } + Ok(()) +} + +// struct Connection { +// connected: oneshot::Sender>, +// } + +// struct ConnectionTable { +// next_id: u64, +// connections: HashMap, +// } + +// type Connections = Arc>; + +async fn client_listen(port: u16) -> Result<(), tokio::io::Error> { + loop { + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)).await?; + loop { + // The second item contains the IP and port of the new connection. + // TODO: Handle shutdown correctly. + let (stream, _) = listener.accept().await?; + + eprintln!("> CONNECTION NOT IMPLEMENTED!"); + } + } +} + +async fn client_read( + reader: &mut T, + _writer: mpsc::Sender, +) -> Result<(), tokio::io::Error> { + let mut listeners: HashMap> = HashMap::new(); + + eprintln!("> Processing packets..."); + loop { + let frame_length = reader.read_u32().await?; + + let mut data = BytesMut::with_capacity(frame_length.try_into().unwrap()); + reader.read_buf(&mut data).await?; + + let mut cursor = Cursor::new(&data[..]); + let message = match Message::decode(&mut cursor) { + Ok(msg) => msg, + Err(_) => return Err(tokio::io::Error::from(tokio::io::ErrorKind::InvalidData)), + }; + + use Message::*; + match message { + Ping => (), + Ports(ports) => { + let mut new_listeners = HashMap::new(); + + println!("The following ports are available:"); + for port in ports { + println!(" {}: {}", port.port, port.desc); + + let port = port.port; + if let Some(l) = listeners.remove(&port) { + if !l.is_closed() { + // Listen could have failed! + new_listeners.insert(port, l); } } + + if !new_listeners.contains_key(&port) { + let (l, stop) = oneshot::channel(); + new_listeners.insert(port, l); + + tokio::spawn(async move { + let result = tokio::select! { + r = client_listen(port) => r, + _ = stop => Ok(()), + }; + if let Err(e) = result { + eprintln!("> Error listening on port {}: {:?}", port, e); + } + }); + } } + + listeners = new_listeners; } - } + _ => panic!("Unsupported: {:?}", message), + }; } - - let mut h: HashMap = HashMap::new(); - - // Go through all the listening IPv4 and IPv6 sockets and take the first - // instance of listening on each port *if* the address is loopback or - // unspecified. (TODO: Do we want this restriction really?) - let tcp = procfs::net::tcp()?; - let tcp6 = procfs::net::tcp6()?; - for tcp_entry in tcp.into_iter().chain(tcp6) { - if tcp_entry.state == procfs::net::TcpState::Listen - && (tcp_entry.local_address.ip().is_loopback() - || tcp_entry.local_address.ip().is_unspecified()) - && !h.contains_key(&tcp_entry.local_address.port()) - { - if let Some(cmd) = map.get(&tcp_entry.inode) { - h.insert( - tcp_entry.local_address.port(), - Entry { - port: tcp_entry.local_address.port(), - desc: cmd.clone(), - }, - ); - } - } - } - - Ok(h.into_values().collect()) } -fn main() { - for e in get_entries().unwrap() { - println!("{}: {}", e.port, e.desc); +async fn client_main( + reader: &mut Reader, + writer: &mut MessageWriter, +) -> Result<(), tokio::io::Error> { + // First synchronize; we're looking for the 8-zero marker that is the 64b sync marker. + // This helps us skip garbage like any kind of MOTD or whatnot. + client_sync(reader).await?; + + // Now kick things off with a listing of the ports... + eprintln!("> Sending initial list command..."); + writer.write(Message::Refresh).await?; + + // And now really get into it... + let (msg_sender, mut msg_receiver) = mpsc::channel(32); + let writing = pump_write(&mut msg_receiver, writer); + let reading = client_read(reader, msg_sender); + tokio::pin!(reading); + tokio::pin!(writing); + + let (mut done_writing, mut done_reading) = (false, false); + loop { + tokio::select! { + result = &mut writing, if !done_writing => { + done_writing = true; + if let Err(e) = result { + return Err(e); + } + if done_reading && done_writing { + return Ok(()); + } + }, + result = &mut reading, if !done_reading => { + done_reading = true; + if let Err(e) = result { + return Err(e); + } + if done_reading && done_writing { + return Ok(()); + } + }, + } + } +} + +#[tokio::main] +async fn main() { + let args: Vec = std::env::args().collect(); + let remote = &args[1]; + if remote == "--server" { + let mut reader = BufReader::new(tokio::io::stdin()); + let mut writer = BufWriter::new(tokio::io::stdout()); + + // Write the marker. + eprintln!("< Writing marker..."); + writer + .write_u64(0x00_00_00_00_00_00_00_00) + .await + .expect("Error writing marker"); + + writer.flush().await.expect("Error flushing buffer"); + eprintln!("< Done!"); + + let mut writer = MessageWriter::new(writer); + + if let Err(e) = server_main(&mut reader, &mut writer).await { + eprintln!("Error: {:?}", e); + } + } else { + let mut child = spawn_ssh(remote).await.expect("failed to spawn"); + + let mut writer = MessageWriter::new(BufWriter::new( + child + .stdin + .take() + .expect("child did not have a handle to stdout"), + )); + + let mut reader = BufReader::new( + child + .stdout + .take() + .expect("child did not have a handle to stdout"), + ); + + if let Err(e) = client_main(&mut reader, &mut writer).await { + eprintln!("Error: {:?}", e); + } } } diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..f0f5801 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,215 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::io::Cursor; + +#[derive(Debug, PartialEq)] +pub enum MessageError { + Incomplete, + UnknownMessage, + Corrupt, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct PortDesc { + pub port: u16, + pub desc: String, +} + +#[derive(Debug, PartialEq)] +pub enum Message { + Ping, + Connect(u64, u16), // Request to connect on a port from client to server. + Connected(u64, u16), // Sucessfully connected from server to client. + Close(u64), // Request to close from client to server. + Abort(u64), // Notify of close from server to client. + Closed(u64), // Response to Close or Abort. + Refresh, // Request to refresh list of ports from client. + Ports(Vec), // List of available ports from server to client. + Data(u64, Bytes), // Transmit data. +} + +impl Message { + pub fn encode(self: &Message) -> BytesMut { + use Message::*; + let mut result = BytesMut::new(); + match self { + Ping => { + result.put_u8(0x00); + } + Connect(channel, port) => { + result.put_u8(0x01); + result.put_u64(*channel); + result.put_u16(*port); + } + Connected(channel, port) => { + result.put_u8(0x02); + result.put_u64(*channel); + result.put_u16(*port); + } + Close(channel) => { + result.put_u8(0x03); + result.put_u64(*channel); + } + Abort(channel) => { + result.put_u8(0x04); + result.put_u64(*channel); + } + Closed(channel) => { + result.put_u8(0x05); + result.put_u64(*channel); + } + Refresh => { + result.put_u8(0x06); + } + Ports(ports) => { + result.put_u8(0x07); + + result.put_u16(ports.len().try_into().expect("Too many ports")); + for port in ports { + result.put_u16(port.port); + + let sliced = slice_up_to(&port.desc, u16::max_value().into()); + result.put_u16(sliced.len().try_into().unwrap()); + result.put_slice(sliced.as_bytes()); + } + } + Data(channel, bytes) => { + result.put_u8(0x08); + result.put_u64(*channel); + result.put_u16(bytes.len().try_into().expect("Payload too big")); + result.put_slice(bytes); // I hate that this copies. We should make this an async write probably. + } + }; + result + } + + pub fn decode(cursor: &mut Cursor<&[u8]>) -> Result { + use Message::*; + match get_u8(cursor)? { + 0x00 => Ok(Ping), + 0x01 => { + let channel = get_u64(cursor)?; + let port = get_u16(cursor)?; + Ok(Connect(channel, port)) + } + 0x02 => { + let channel = get_u64(cursor)?; + let port = get_u16(cursor)?; + Ok(Connected(channel, port)) + } + 0x03 => { + let channel = get_u64(cursor)?; + Ok(Close(channel)) + } + 0x04 => { + let channel = get_u64(cursor)?; + Ok(Abort(channel)) + } + 0x05 => { + let channel = get_u64(cursor)?; + Ok(Closed(channel)) + } + 0x06 => Ok(Refresh), + 0x07 => { + let count = get_u16(cursor)?; + + let mut ports = Vec::new(); + for _ in 0..count { + let port = get_u16(cursor)?; + let length = get_u16(cursor)?; + + let data = get_bytes(cursor, length.into())?; + let desc = match std::str::from_utf8(&data[..]) { + Ok(s) => s.to_owned(), + Err(_) => return Err(MessageError::Corrupt), + }; + + ports.push(PortDesc { port, desc }); + } + Ok(Ports(ports)) + } + 0x08 => { + let channel = get_u64(cursor)?; + let length = get_u16(cursor)?; + let data = get_bytes(cursor, length.into())?; + Ok(Data(channel, data)) + } + _ => Err(MessageError::UnknownMessage), + } + } +} + +#[cfg(test)] +mod message_tests { + use crate::message::Message; + use crate::message::Message::*; + use crate::message::PortDesc; + + fn assert_round_trip(message: Message) { + let encoded = message.encode(); + let mut cursor = std::io::Cursor::new(&encoded[..]); + let result = Message::decode(&mut cursor); + assert_eq!(Ok(message), result); + } + + #[test] + fn round_trip() { + assert_round_trip(Ping); + assert_round_trip(Connect(0x1234567890123456, 0x1234)); + assert_round_trip(Connected(0x1234567890123456, 0x1234)); + assert_round_trip(Close(0x1234567890123456)); + assert_round_trip(Abort(0x1234567890123456)); + assert_round_trip(Closed(0x1234567890123456)); + assert_round_trip(Refresh); + assert_round_trip(Ports(vec![ + PortDesc { + port: 8080, + desc: "query-service".to_string(), + }, + PortDesc { + port: 9090, + desc: "metadata-library".to_string(), + }, + ])); + assert_round_trip(Data(0x1234567890123456, vec![1, 2, 3, 4].into())); + } +} + +fn get_u8(cursor: &mut Cursor<&[u8]>) -> Result { + if !cursor.has_remaining() { + return Err(MessageError::Incomplete); + } + Ok(cursor.get_u8()) +} + +fn get_u16(cursor: &mut Cursor<&[u8]>) -> Result { + if cursor.remaining() < 2 { + return Err(MessageError::Incomplete); + } + Ok(cursor.get_u16()) +} + +fn get_u64(cursor: &mut Cursor<&[u8]>) -> Result { + if cursor.remaining() < 8 { + return Err(MessageError::Incomplete); + } + Ok(cursor.get_u64()) +} + +fn get_bytes(cursor: &mut Cursor<&[u8]>, length: usize) -> Result { + if cursor.remaining() < length { + return Err(MessageError::Incomplete); + } + + Ok(cursor.copy_to_bytes(length)) +} + +pub fn slice_up_to(s: &str, max_len: usize) -> &str { + if max_len >= s.len() { + return s; + } + let mut idx = max_len; + while !s.is_char_boundary(idx) { + idx -= 1; + } + &s[..idx] +} diff --git a/src/refresh.rs b/src/refresh.rs new file mode 100644 index 0000000..2255e44 --- /dev/null +++ b/src/refresh.rs @@ -0,0 +1,56 @@ +use crate::message::PortDesc; +use procfs::process::FDTarget; +use std::collections::HashMap; + +pub fn get_entries() -> procfs::ProcResult> { + let all_procs = procfs::process::all_processes()?; + + // build up a map between socket inodes and process stat info. Ignore any + // error we encounter as it probably means we have no access to that + // process or something. + let mut map: HashMap = HashMap::new(); + for p in all_procs { + if let Ok(process) = p { + if !process.is_alive() { + continue; // Ignore zombies. + } + + if let (Ok(fds), Ok(cmd)) = (process.fd(), process.cmdline()) { + for fd in fds { + if let Ok(fd) = fd { + if let FDTarget::Socket(inode) = fd.target { + map.insert(inode, cmd.join(" ")); + } + } + } + } + } + } + + let mut h: HashMap = HashMap::new(); + + // Go through all the listening IPv4 and IPv6 sockets and take the first + // instance of listening on each port *if* the address is loopback or + // unspecified. (TODO: Do we want this restriction really?) + let tcp = procfs::net::tcp()?; + let tcp6 = procfs::net::tcp6()?; + for tcp_entry in tcp.into_iter().chain(tcp6) { + if tcp_entry.state == procfs::net::TcpState::Listen + && (tcp_entry.local_address.ip().is_loopback() + || tcp_entry.local_address.ip().is_unspecified()) + && !h.contains_key(&tcp_entry.local_address.port()) + { + if let Some(cmd) = map.get(&tcp_entry.inode) { + h.insert( + tcp_entry.local_address.port(), + PortDesc { + port: tcp_entry.local_address.port(), + desc: cmd.clone(), + }, + ); + } + } + } + + Ok(h.into_values().collect()) +} diff --git a/test.py b/test.py new file mode 100644 index 0000000..f4fc40a --- /dev/null +++ b/test.py @@ -0,0 +1,31 @@ +#!/bin/env python3 +import os +import subprocess +import sys + + +def local(*args): + subprocess.run(list(args), check=True) + + +def ssh(remote, *args): + subprocess.run(["ssh", remote] + list(args), check=True, capture_output=True) + + +def main(args): + local("cargo", "build") + + remote = args[1] + print(f"Copying file to {remote}...") + subprocess.run( + ["scp", "target/debug/fwd", f"{remote}:bin/fwd"], + check=True, + capture_output=True, + ) + + print(f"Starting process...") + subprocess.run(["target/debug/fwd", remote]) + + +if __name__ == "__main__": + main(sys.argv)