use super::api::{CommandRunner, ContextDriver}; use log::debug; use std::fs; /// SSH context: execute commands over an SSH connection /// Context driver: Copies over SFTP with ssh2 /// Command runner: Executes commands over an ssh2 channel (with PTY) use std::io::{self, Read}; use std::net::TcpStream; #[cfg(unix)] use std::os::unix::process::ExitStatusExt; use std::path::{Path, PathBuf}; use std::process::ExitStatus; pub fn connect_ssh(host: &str, user: Option<&str>, port: Option) -> io::Result { let port = port.unwrap_or(22); let tcp = TcpStream::connect((host, port))?; let mut session = ssh2::Session::new().unwrap(); session.set_tcp_stream(tcp); session .handshake() .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?; // Username: if set, use it; else, use local 'USER', and if unset, use 'root' let local_user_string; let user = match user { Some(u) => u, None => { local_user_string = std::env::var("USER").unwrap_or_else(|_| "root".to_string()); &local_user_string } }; if session.userauth_agent(user).is_ok() { return Ok(session); } if !session.authenticated() { return Err(io::Error::new( io::ErrorKind::PermissionDenied, "SSH authentication failed (tried agent)", )); } Ok(session) } pub struct SshDriver { pub host: String, pub user: Option, pub port: Option, } impl ContextDriver for SshDriver { fn ensure_available(&self, src: &Path, dest_root: &str) -> io::Result { let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?; let sftp = sess.sftp().map_err(io::Error::other)?; let file_name = src .file_name() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "src has no filename"))?; let remote_dest = Path::new(dest_root).join(file_name); debug!("Uploading {:?} to remote {:?}", src, remote_dest); Self::upload_recursive(&sftp, src, &remote_dest)?; Ok(remote_dest) } fn create_runner(&self, program: String) -> Box { Box::new(SshRunner { host: self.host.clone(), user: self.user.clone(), port: self.port, program, args: Vec::new(), }) } fn prepare_work_dir(&self) -> io::Result { let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?; let mut channel = sess.channel_session().map_err(io::Error::other)?; channel.exec("mktemp -d").map_err(io::Error::other)?; let mut stdout = String::new(); channel.read_to_string(&mut stdout)?; channel.wait_close().map_err(io::Error::other)?; if channel.exit_status().unwrap_or(-1) != 0 { return Err(io::Error::other( "Failed to create remote temporary directory", )); } Ok(stdout.trim().to_string()) } } impl SshDriver { fn upload_recursive(sftp: &ssh2::Sftp, src: &Path, dest: &Path) -> io::Result<()> { if src.is_dir() { // Create dir let _ = sftp.mkdir(dest, 0o755); for entry in fs::read_dir(src)? { let entry = entry?; let path = entry.path(); let name = entry.file_name(); let dest_path = dest.join(name); Self::upload_recursive(sftp, &path, &dest_path)?; } } else { let mut file = fs::File::open(src)?; let mut remote_file = sftp.create(dest).map_err(|e| { io::Error::other(format!("Failed to create remote file {:?}: {}", dest, e)) })?; io::copy(&mut file, &mut remote_file)?; } Ok(()) } } pub struct SshRunner { pub host: String, pub user: Option, pub port: Option, pub program: String, pub args: Vec, } impl SshRunner { fn prepare_channel(&self) -> io::Result<(ssh2::Channel, String)> { let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?; let channel = sess.channel_session().map_err(io::Error::other)?; // Construct command line let mut cmd_line = self.program.clone(); for arg in &self.args { cmd_line.push(' '); cmd_line.push_str(arg); // TODO: escape } Ok((channel, cmd_line)) } } impl CommandRunner for SshRunner { fn add_arg(&mut self, arg: String) { self.args.push(arg); } fn status(&mut self) -> io::Result { let (mut channel, cmd_line) = self.prepare_channel()?; debug!("Executing SSH command: {}", cmd_line); channel .request_pty("xterm", None, None) .map_err(|e| io::Error::other(format!("Failed to request PTY: {}", e)))?; channel.exec(&cmd_line).map_err(io::Error::other)?; let mut stdout_stream = channel.stream(0); let mut stdout = io::stdout(); io::copy(&mut stdout_stream, &mut stdout)?; channel.wait_close().map_err(io::Error::other)?; let code = channel.exit_status().unwrap_or(-1); Ok(ExitStatus::from_raw(code)) } fn output(&mut self) -> io::Result { let (mut channel, cmd_line) = self.prepare_channel()?; channel.exec(&cmd_line).map_err(io::Error::other)?; let mut stdout = Vec::new(); channel.read_to_end(&mut stdout)?; let mut stderr = Vec::new(); channel.stderr().read_to_end(&mut stderr)?; channel.wait_close().map_err(io::Error::other)?; #[cfg(unix)] let status = std::process::ExitStatus::from_raw(channel.exit_status().unwrap_or(-1)); #[cfg(not(unix))] let status = { panic!("SSH output capture only supported on Unix"); }; Ok(std::process::Output { status, stdout, stderr, }) } }