//! SSH context: execute commands over an SSH connection.
//!
//! Context driver: copies files over SFTP with ssh2 and executes commands over ssh2 channels.
use super::api::ContextDriver;
|
|
use log::debug;
|
|
use ssh2;
|
|
use std::fs;
|
|
use std::io::Write;
|
|
use std::io::{self, Read};
|
|
use std::net::TcpStream;
|
|
#[cfg(unix)]
|
|
use std::os::unix::process::ExitStatusExt;
|
|
use std::path::{Path, PathBuf};
|
|
use std::process::ExitStatus;
|
|
|
|
pub fn connect_ssh(host: &str, user: Option<&str>, port: Option<u16>) -> io::Result<ssh2::Session> {
|
|
let port = port.unwrap_or(22);
|
|
let tcp = TcpStream::connect((host, port))?;
|
|
let mut session = ssh2::Session::new().unwrap();
|
|
session.set_tcp_stream(tcp);
|
|
session
|
|
.handshake()
|
|
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
|
|
|
|
// Username: if set, use it; else, use local 'USER', and if unset, use 'root'
|
|
let local_user_string;
|
|
let user = match user {
|
|
Some(u) => u,
|
|
None => {
|
|
local_user_string = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
|
|
&local_user_string
|
|
}
|
|
};
|
|
|
|
if session.userauth_agent(user).is_ok() {
|
|
return Ok(session);
|
|
}
|
|
|
|
if !session.authenticated() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::PermissionDenied,
|
|
"SSH authentication failed (tried agent)",
|
|
));
|
|
}
|
|
|
|
Ok(session)
|
|
}
|
|
|
|
/// Connection parameters for a context that lives on a remote host reached over SSH.
///
/// Every driver operation opens its own connection by passing these fields to
/// `connect_ssh`.
#[derive(Debug, Clone)]
pub struct SshDriver {
    /// Host name or address of the remote machine.
    pub host: String,
    /// Remote user name; `None` defers the choice to `connect_ssh`.
    pub user: Option<String>,
    /// TCP port of the SSH server; `None` defers the choice to `connect_ssh`.
    pub port: Option<u16>,
}
|
|
|
|
impl ContextDriver for SshDriver {
|
|
fn ensure_available(&self, src: &Path, dest_root: &str) -> io::Result<PathBuf> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let sftp = sess.sftp().map_err(io::Error::other)?;
|
|
|
|
let file_name = src
|
|
.file_name()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "src has no filename"))?;
|
|
let remote_dest = Path::new(dest_root).join(file_name);
|
|
|
|
debug!("Uploading {:?} to remote {:?}", src, remote_dest);
|
|
|
|
Self::upload_recursive(&sftp, src, &remote_dest)?;
|
|
|
|
Ok(remote_dest)
|
|
}
|
|
|
|
fn retrieve_path(&self, src: &Path, dest: &Path) -> io::Result<()> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let sftp = sess.sftp().map_err(io::Error::other)?;
|
|
|
|
debug!("Downloading remote {:?} to {:?}", src, dest);
|
|
Self::download_recursive(&sftp, src, dest)
|
|
}
|
|
|
|
fn list_files(&self, path: &Path) -> io::Result<Vec<PathBuf>> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let sftp = sess.sftp().map_err(io::Error::other)?;
|
|
|
|
let mut files = Vec::new();
|
|
let entries = sftp.readdir(path).map_err(io::Error::other)?;
|
|
for (p, _) in entries {
|
|
let file_name = p.file_name().unwrap();
|
|
if file_name == "." || file_name == ".." {
|
|
continue;
|
|
}
|
|
files.push(p);
|
|
}
|
|
Ok(files)
|
|
}
|
|
|
|
fn run(
|
|
&self,
|
|
program: &str,
|
|
args: &[String],
|
|
env: &[(String, String)],
|
|
cwd: Option<&str>,
|
|
) -> io::Result<std::process::ExitStatus> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let mut channel = sess.channel_session().map_err(io::Error::other)?;
|
|
|
|
// Construct command line with env vars
|
|
// TODO: No, use ssh2 channel.set_env
|
|
let mut cmd_line = String::new();
|
|
for (key, value) in env {
|
|
cmd_line.push_str(&format!(
|
|
"export {}='{}'; ",
|
|
key,
|
|
value.replace("'", "'\\''")
|
|
));
|
|
}
|
|
if let Some(dir) = cwd {
|
|
cmd_line.push_str(&format!("cd {} && ", dir));
|
|
}
|
|
cmd_line.push_str(program);
|
|
for arg in args {
|
|
cmd_line.push(' ');
|
|
cmd_line.push_str(arg); // TODO: escape
|
|
}
|
|
|
|
debug!("Executing SSH command: {}", cmd_line);
|
|
|
|
channel
|
|
.request_pty("xterm", None, None)
|
|
.map_err(|e| io::Error::other(format!("Failed to request PTY: {}", e)))?;
|
|
|
|
channel.exec(&cmd_line).map_err(io::Error::other)?;
|
|
|
|
let mut stdout_stream = channel.stream(0);
|
|
let mut stdout = io::stdout();
|
|
io::copy(&mut stdout_stream, &mut stdout)?;
|
|
|
|
channel.wait_close().map_err(io::Error::other)?;
|
|
|
|
let code = channel.exit_status().unwrap_or(-1);
|
|
Ok(ExitStatus::from_raw(code))
|
|
}
|
|
|
|
fn run_output(
|
|
&self,
|
|
program: &str,
|
|
args: &[String],
|
|
env: &[(String, String)],
|
|
cwd: Option<&str>,
|
|
) -> io::Result<std::process::Output> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let mut channel = sess.channel_session().map_err(io::Error::other)?;
|
|
|
|
// Construct command line with env vars
|
|
let mut cmd_line = String::new();
|
|
for (key, value) in env {
|
|
cmd_line.push_str(&format!(
|
|
"export {}='{}'; ",
|
|
key,
|
|
value.replace("'", "'\\''")
|
|
));
|
|
}
|
|
if let Some(dir) = cwd {
|
|
cmd_line.push_str(&format!("cd {} && ", dir));
|
|
}
|
|
cmd_line.push_str(program);
|
|
for arg in args {
|
|
cmd_line.push(' ');
|
|
cmd_line.push_str(arg); // TODO: escape
|
|
}
|
|
|
|
channel.exec(&cmd_line).map_err(io::Error::other)?;
|
|
|
|
let mut stdout = Vec::new();
|
|
channel.read_to_end(&mut stdout)?;
|
|
|
|
let mut stderr = Vec::new();
|
|
channel.stderr().read_to_end(&mut stderr)?;
|
|
|
|
channel.wait_close().map_err(io::Error::other)?;
|
|
|
|
#[cfg(unix)]
|
|
let status = std::process::ExitStatus::from_raw(channel.exit_status().unwrap_or(-1));
|
|
|
|
#[cfg(not(unix))]
|
|
let status = {
|
|
panic!("SSH output capture only supported on Unix");
|
|
};
|
|
|
|
Ok(std::process::Output {
|
|
status,
|
|
stdout,
|
|
stderr,
|
|
})
|
|
}
|
|
|
|
fn create_temp_dir(&self) -> io::Result<String> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let mut channel = sess.channel_session().map_err(io::Error::other)?;
|
|
|
|
channel.exec("mktemp -d").map_err(io::Error::other)?;
|
|
|
|
let mut stdout = String::new();
|
|
channel.read_to_string(&mut stdout)?;
|
|
channel.wait_close().map_err(io::Error::other)?;
|
|
|
|
if channel.exit_status().unwrap_or(-1) != 0 {
|
|
return Err(io::Error::other(
|
|
"Failed to create remote temporary directory",
|
|
));
|
|
}
|
|
|
|
Ok(stdout.trim().to_string())
|
|
}
|
|
|
|
fn copy_path(&self, src: &Path, dest: &Path) -> io::Result<()> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let mut channel = sess.channel_session().map_err(io::Error::other)?;
|
|
// TODO: use sftp
|
|
let cmd = format!("cp -a {:?} {:?}", src, dest);
|
|
debug!("Executing remote copy: {}", cmd);
|
|
channel.exec(&cmd).map_err(io::Error::other)?;
|
|
channel.wait_close().map_err(io::Error::other)?;
|
|
if channel.exit_status().unwrap_or(-1) != 0 {
|
|
return Err(io::Error::other(format!("Remote copy failed: {}", cmd)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn read_file(&self, path: &Path) -> io::Result<String> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let sftp = sess.sftp().map_err(io::Error::other)?;
|
|
let mut remote_file = sftp.open(path).map_err(io::Error::other)?;
|
|
let mut content = String::new();
|
|
remote_file.read_to_string(&mut content)?;
|
|
Ok(content)
|
|
}
|
|
|
|
fn write_file(&self, path: &Path, content: &str) -> io::Result<()> {
|
|
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
|
|
let sftp = sess.sftp().map_err(io::Error::other)?;
|
|
if let Some(parent) = path.parent() {
|
|
let _ = sftp.mkdir(parent, 0o755);
|
|
}
|
|
let mut remote_file = sftp.create(path).map_err(io::Error::other)?;
|
|
remote_file.write_all(content.as_bytes())?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl SshDriver {
|
|
fn upload_recursive(sftp: &ssh2::Sftp, src: &Path, dest: &Path) -> io::Result<()> {
|
|
if src.is_dir() {
|
|
// Create dir
|
|
let _ = sftp.mkdir(dest, 0o755);
|
|
for entry in fs::read_dir(src)? {
|
|
let entry = entry?;
|
|
let path = entry.path();
|
|
let name = entry.file_name();
|
|
let dest_path = dest.join(name);
|
|
Self::upload_recursive(sftp, &path, &dest_path)?;
|
|
}
|
|
} else {
|
|
let mut file = fs::File::open(src)?;
|
|
let mut remote_file = sftp.create(dest).map_err(|e| {
|
|
io::Error::other(format!("Failed to create remote file {:?}: {}", dest, e))
|
|
})?;
|
|
io::copy(&mut file, &mut remote_file)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn download_recursive(sftp: &ssh2::Sftp, src: &Path, dest: &Path) -> io::Result<()> {
|
|
let stat = sftp
|
|
.stat(src)
|
|
.map_err(|e| io::Error::other(format!("Remote stat failed for {:?}: {}", src, e)))?;
|
|
|
|
if stat.is_dir() {
|
|
fs::create_dir_all(dest)?;
|
|
let entries = sftp.readdir(src).map_err(io::Error::other)?;
|
|
for (path, _) in entries {
|
|
let file_name = path.file_name().unwrap();
|
|
if file_name == "." || file_name == ".." {
|
|
continue;
|
|
}
|
|
let new_dest = dest.join(file_name);
|
|
Self::download_recursive(sftp, &path, &new_dest)?;
|
|
}
|
|
} else {
|
|
let mut remote_file = sftp.open(src).map_err(io::Error::other)?;
|
|
let mut local_file = fs::File::create(dest)?;
|
|
io::copy(&mut remote_file, &mut local_file)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|