diff --git a/src/lib.rs b/src/lib.rs index b17eeab..a6983cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,22 +161,38 @@ async fn server_main( } } -async fn client_sync(reader: &mut T) -> Result<()> { - // TODO: While we're waiting here we should be echoing everything we read. - // We should also be proxying *our* stdin to the processes stdin, - // and turn that off when we've synchronized. That way we can - // handle passwords and the like for authentication. +async fn client_sync( + reader: &mut Read, + writer: &mut Write, +) -> Result<(), tokio::io::Error> { eprintln!("> Waiting for synchronization marker..."); - let mut seen = 0; - while seen < 8 { - let byte = reader.read_u8().await?; - if byte == 0 { - seen += 1; - } else { - tokio::io::stdout().write_u8(byte).await?; - } + + // Run these two loops in parallel; the copy of stdin should stop when + // we've seen the marker from the client. If the pipe closes for whatever + // reason then obviously we quit. + let mut stdout = tokio::io::stdout(); + tokio::select! { + result = async { + let mut stdin = tokio::io::stdin(); + tokio::io::copy(&mut stdin, writer).await + } => match result { + Ok(_) => Ok(()), + Err(e) => Err(e), + }, + result = async { + let mut seen = 0; + while seen < 8 { + let byte = reader.read_u8().await?; + if byte == 0 { + seen += 1; + } else { + stdout.write_u8(byte).await?; + } + } + + Ok::<_, tokio::io::Error>(()) + } => result, } - Ok(()) } async fn client_handle_connection( @@ -388,12 +404,10 @@ pub async fn run_client(remote: &str) { // TODO: Drive a reconnect loop let mut child = spawn_ssh(remote).await.expect("failed to spawn"); - let mut writer = MessageWriter::new(BufWriter::new( - child - .stdin - .take() - .expect("child did not have a handle to stdout"), - )); + let mut writer = child + .stdin + .take() + .expect("child did not have a handle to stdin"); let mut reader = BufReader::new( child @@ -402,11 +416,12 @@ pub async fn run_client(remote: &str) { .expect("child did not have a handle to stdout"), ); - if let Err(e) = client_sync(&mut reader).await { + if let Err(e) = client_sync(&mut reader, &mut writer).await { eprintln!("Error synchronizing: {:?}", e); return; } + let mut writer = MessageWriter::new(BufWriter::new(writer)); let mut reader = MessageReader::new(reader); if let Err(e) = client_main(&mut reader, &mut writer).await { eprintln!("Error: {:?}", e);