Files
pkh/src/context/ssh.rs
2025-12-20 00:06:07 +01:00

295 lines
9.7 KiB
Rust

/// SSH context: execute commands over an SSH connection
/// Context driver: Copies over SFTP with ssh2, executes commands over ssh2 channels
use super::api::ContextDriver;
use log::debug;
use ssh2;
use std::fs;
use std::io::Write;
use std::io::{self, Read};
use std::net::TcpStream;
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;
use std::path::{Path, PathBuf};
use std::process::ExitStatus;
pub fn connect_ssh(host: &str, user: Option<&str>, port: Option<u16>) -> io::Result<ssh2::Session> {
let port = port.unwrap_or(22);
let tcp = TcpStream::connect((host, port))?;
let mut session = ssh2::Session::new().unwrap();
session.set_tcp_stream(tcp);
session
.handshake()
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
// Username: if set, use it; else, use local 'USER', and if unset, use 'root'
let local_user_string;
let user = match user {
Some(u) => u,
None => {
local_user_string = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
&local_user_string
}
};
if session.userauth_agent(user).is_ok() {
return Ok(session);
}
if !session.authenticated() {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"SSH authentication failed (tried agent)",
));
}
Ok(session)
}
/// Context driver that runs commands and transfers files on a remote host over SSH.
///
/// Each driver operation opens its own SSH connection using these settings.
#[derive(Debug, Clone)]
pub struct SshDriver {
    /// Host name or address to connect to.
    pub host: String,
    /// Login name; when `None`, the local `USER` environment variable (or `root`) is used.
    pub user: Option<String>,
    /// TCP port; when `None`, port 22 is used.
    pub port: Option<u16>,
}
impl ContextDriver for SshDriver {
    /// Upload the local file or directory tree `src` into the remote directory `dest_root`.
    ///
    /// Returns the remote path `dest_root/<file name of src>`.
    ///
    /// # Errors
    /// Fails with `InvalidInput` if `src` has no file name, or propagates connection,
    /// local read and remote write errors.
    fn ensure_available(&self, src: &Path, dest_root: &str) -> io::Result<PathBuf> {
        // Validate the input before paying for a connection.
        let file_name = src
            .file_name()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "src has no filename"))?;
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let sftp = sess.sftp().map_err(io::Error::other)?;
        let remote_dest = Path::new(dest_root).join(file_name);
        debug!("Uploading {:?} to remote {:?}", src, remote_dest);
        Self::upload_recursive(&sftp, src, &remote_dest)?;
        Ok(remote_dest)
    }

    /// Download the remote file or directory tree `src` to the local path `dest`.
    fn retrieve_path(&self, src: &Path, dest: &Path) -> io::Result<()> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let sftp = sess.sftp().map_err(io::Error::other)?;
        debug!("Downloading remote {:?} to {:?}", src, dest);
        Self::download_recursive(&sftp, src, dest)
    }

    /// List the entries of the remote directory `path` (as paths joined onto `path`).
    ///
    /// Entries without a usable final component, `.` and `..` are skipped.
    fn list_files(&self, path: &Path) -> io::Result<Vec<PathBuf>> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let sftp = sess.sftp().map_err(io::Error::other)?;
        let entries = sftp.readdir(path).map_err(io::Error::other)?;
        Ok(entries
            .into_iter()
            .map(|(p, _)| p)
            // `file_name()` is `None` for paths ending in `..`; skip those instead of panicking.
            .filter(|p| matches!(p.file_name(), Some(name) if name != "." && name != ".."))
            .collect())
    }

    /// Run `program args…` remotely under a PTY, streaming its output to local stdout.
    ///
    /// Environment variables are exported and `cwd` is entered before running the program;
    /// every word is shell-quoted, so arguments are passed verbatim.
    ///
    /// Because a PTY is allocated, the remote stderr is merged into the streamed output.
    fn run(
        &self,
        program: &str,
        args: &[String],
        env: &[(String, String)],
        cwd: Option<&str>,
    ) -> io::Result<std::process::ExitStatus> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let mut channel = sess.channel_session().map_err(io::Error::other)?;
        let cmd_line = build_command_line(program, args, env, cwd);
        debug!("Executing SSH command: {}", cmd_line);
        channel
            .request_pty("xterm", None, None)
            .map_err(|e| io::Error::other(format!("Failed to request PTY: {}", e)))?;
        channel.exec(&cmd_line).map_err(io::Error::other)?;
        let mut stdout_stream = channel.stream(0);
        let mut stdout = io::stdout().lock();
        io::copy(&mut stdout_stream, &mut stdout)?;
        // Make sure a trailing partial line reaches the terminal.
        stdout.flush()?;
        channel.wait_close().map_err(io::Error::other)?;
        exit_status_from_code(channel.exit_status().unwrap_or(-1))
    }

    /// Run `program args…` remotely (no PTY) and capture stdout, stderr and exit status.
    ///
    /// NOTE(review): stdout is drained completely before stderr is read; a command that
    /// writes a very large amount to stderr might stall on the SSH channel window — confirm.
    fn run_output(
        &self,
        program: &str,
        args: &[String],
        env: &[(String, String)],
        cwd: Option<&str>,
    ) -> io::Result<std::process::Output> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let mut channel = sess.channel_session().map_err(io::Error::other)?;
        let cmd_line = build_command_line(program, args, env, cwd);
        debug!("Executing SSH command (captured): {}", cmd_line);
        channel.exec(&cmd_line).map_err(io::Error::other)?;
        let mut stdout = Vec::new();
        channel.read_to_end(&mut stdout)?;
        let mut stderr = Vec::new();
        channel.stderr().read_to_end(&mut stderr)?;
        channel.wait_close().map_err(io::Error::other)?;
        let status = exit_status_from_code(channel.exit_status().unwrap_or(-1))?;
        Ok(std::process::Output {
            status,
            stdout,
            stderr,
        })
    }

    /// Create a remote temporary directory with `mktemp -d` and return its path.
    fn create_temp_dir(&self) -> io::Result<String> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let mut channel = sess.channel_session().map_err(io::Error::other)?;
        channel.exec("mktemp -d").map_err(io::Error::other)?;
        let mut stdout = String::new();
        channel.read_to_string(&mut stdout)?;
        channel.wait_close().map_err(io::Error::other)?;
        if channel.exit_status().unwrap_or(-1) != 0 {
            return Err(io::Error::other(
                "Failed to create remote temporary directory",
            ));
        }
        Ok(stdout.trim().to_string())
    }

    /// Copy `src` to `dest` on the remote host with `cp -a`.
    ///
    /// # Errors
    /// Returns an error containing the command (and the remote stderr, if any) when `cp`
    /// exits with a nonzero status.
    fn copy_path(&self, src: &Path, dest: &Path) -> io::Result<()> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let mut channel = sess.channel_session().map_err(io::Error::other)?;
        // TODO: use sftp
        // Paths are single-quoted so `$`, backticks and spaces are not interpreted by the shell.
        let cmd = format!(
            "cp -a {} {}",
            shell_quote(&src.to_string_lossy()),
            shell_quote(&dest.to_string_lossy())
        );
        debug!("Executing remote copy: {}", cmd);
        channel.exec(&cmd).map_err(io::Error::other)?;
        let mut err_bytes = Vec::new();
        channel.stderr().read_to_end(&mut err_bytes)?;
        channel.wait_close().map_err(io::Error::other)?;
        if channel.exit_status().unwrap_or(-1) != 0 {
            let err_text = String::from_utf8_lossy(&err_bytes);
            let detail = err_text.trim();
            let message = if detail.is_empty() {
                format!("Remote copy failed: {}", cmd)
            } else {
                format!("Remote copy failed: {}: {}", cmd, detail)
            };
            return Err(io::Error::other(message));
        }
        Ok(())
    }

    /// Read the remote file at `path` as UTF-8 text.
    fn read_file(&self, path: &Path) -> io::Result<String> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let sftp = sess.sftp().map_err(io::Error::other)?;
        let mut remote_file = sftp.open(path).map_err(io::Error::other)?;
        let mut content = String::new();
        remote_file.read_to_string(&mut content)?;
        Ok(content)
    }

    /// Write `content` to the remote file at `path`, replacing any existing content.
    ///
    /// The immediate parent directory is created on a best-effort basis (one level only;
    /// a failing `mkdir`, e.g. because it already exists, is ignored).
    fn write_file(&self, path: &Path, content: &str) -> io::Result<()> {
        let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
        let sftp = sess.sftp().map_err(io::Error::other)?;
        if let Some(parent) = path.parent() {
            let _ = sftp.mkdir(parent, 0o755);
        }
        let mut remote_file = sftp.create(path).map_err(io::Error::other)?;
        remote_file.write_all(content.as_bytes())?;
        Ok(())
    }
}

/// Quote `word` so a POSIX `sh` treats it as a single literal word.
///
/// Words made only of characters with no special meaning to the shell are returned
/// unchanged (keeps logged command lines readable); anything else, including the empty
/// string, is wrapped in single quotes with embedded `'` written as `'\''`.
fn shell_quote(word: &str) -> String {
    let plain = !word.is_empty()
        && word.chars().all(|c| {
            c.is_ascii_alphanumeric()
                || matches!(c, '-' | '_' | '.' | '/' | ':' | ',' | '+' | '@' | '%')
        });
    if plain {
        word.to_string()
    } else {
        format!("'{}'", word.replace('\'', "'\\''"))
    }
}

/// Build the remote `sh` command line: `export K=V; … cd DIR && PROGRAM ARGS…`.
///
/// Values, the directory, the program and each argument are shell-quoted so they reach the
/// remote program verbatim, matching how a local process would receive them. Environment
/// variables are exported in the command itself rather than via `Channel::setenv`, which
/// sshd typically refuses unless the variable is listed in `AcceptEnv`.
fn build_command_line(
    program: &str,
    args: &[String],
    env: &[(String, String)],
    cwd: Option<&str>,
) -> String {
    let mut cmd_line = String::new();
    for (key, value) in env {
        // NOTE(review): keys are inserted verbatim and must be valid shell variable names.
        cmd_line.push_str(&format!("export {}={}; ", key, shell_quote(value)));
    }
    if let Some(dir) = cwd {
        cmd_line.push_str(&format!("cd {} && ", shell_quote(dir)));
    }
    let words: Vec<String> = std::iter::once(program)
        .chain(args.iter().map(String::as_str))
        .map(shell_quote)
        .collect();
    cmd_line.push_str(&words.join(" "));
    cmd_line
}

/// Convert a remote exit code (or `-1` when unavailable) into an [`ExitStatus`].
///
/// `ExitStatusExt::from_raw` on Unix expects a raw `wait(2)` status, in which a normal exit
/// code lives in bits 8..16; passing the exit code directly would be decoded as "killed by
/// signal N" and `ExitStatus::code()` would return `None`. Only the low 8 bits are kept, as
/// for a local process, so the `-1` sentinel becomes exit code 255 (still a failure).
#[cfg(unix)]
fn exit_status_from_code(code: i32) -> io::Result<ExitStatus> {
    Ok(ExitStatus::from_raw((code & 0xff) << 8))
}

/// Convert a remote exit code (or `-1` when unavailable) into an [`ExitStatus`].
///
/// Windows stores the process exit code directly; the `as` cast reinterprets the bits.
#[cfg(windows)]
fn exit_status_from_code(code: i32) -> io::Result<ExitStatus> {
    use std::os::windows::process::ExitStatusExt;
    Ok(ExitStatus::from_raw(code as u32))
}

/// Platforms without a way to construct an [`ExitStatus`] report an error instead of panicking.
#[cfg(not(any(unix, windows)))]
fn exit_status_from_code(_code: i32) -> io::Result<ExitStatus> {
    Err(io::Error::new(
        io::ErrorKind::Unsupported,
        "SSH exit status conversion is not supported on this platform",
    ))
}
impl SshDriver {
    /// Recursively upload the local path `src` to the remote path `dest` over SFTP.
    ///
    /// Directories are created remotely with mode `0755`; a failing `mkdir` is ignored so
    /// uploading into an already existing directory works. Regular files are created with
    /// mode `0644` plus (on Unix) any execute bits of the local file, so scripts stay runnable.
    ///
    /// # Errors
    /// Propagates local read/metadata errors and remote create/write errors.
    fn upload_recursive(sftp: &ssh2::Sftp, src: &Path, dest: &Path) -> io::Result<()> {
        if src.is_dir() {
            // Best-effort: the directory may already exist on the remote side.
            let _ = sftp.mkdir(dest, 0o755);
            for entry in fs::read_dir(src)? {
                let entry = entry?;
                Self::upload_recursive(sftp, &entry.path(), &dest.join(entry.file_name()))?;
            }
            return Ok(());
        }

        let mut file = fs::File::open(src)?;
        let mode = Self::upload_mode(&file)?;
        let mut remote_file = sftp
            .open_mode(
                dest,
                ssh2::OpenFlags::WRITE | ssh2::OpenFlags::CREATE | ssh2::OpenFlags::TRUNCATE,
                mode,
                ssh2::OpenType::File,
            )
            .map_err(|e| {
                io::Error::other(format!("Failed to create remote file {:?}: {}", dest, e))
            })?;
        io::copy(&mut file, &mut remote_file)?;
        Ok(())
    }

    /// Recursively download the remote path `src` to the local path `dest` over SFTP.
    ///
    /// Local directories are created with `create_dir_all`; entries whose path has no usable
    /// final component, as well as `.` and `..`, are skipped.
    ///
    /// # Errors
    /// Propagates remote stat/read errors and local create/write errors.
    fn download_recursive(sftp: &ssh2::Sftp, src: &Path, dest: &Path) -> io::Result<()> {
        let stat = sftp
            .stat(src)
            .map_err(|e| io::Error::other(format!("Remote stat failed for {:?}: {}", src, e)))?;
        if stat.is_dir() {
            fs::create_dir_all(dest)?;
            let entries = sftp.readdir(src).map_err(io::Error::other)?;
            for (path, _) in entries {
                // `file_name()` is `None` for paths ending in `..`; skip instead of panicking.
                let Some(file_name) = path.file_name() else {
                    continue;
                };
                if file_name == "." || file_name == ".." {
                    continue;
                }
                Self::download_recursive(sftp, &path, &dest.join(file_name))?;
            }
            return Ok(());
        }

        let mut remote_file = sftp.open(src).map_err(io::Error::other)?;
        let mut local_file = fs::File::create(dest)?;
        io::copy(&mut remote_file, &mut local_file)?;
        Ok(())
    }

    /// Mode for a newly created remote copy of `file`: `0644` plus the local execute bits.
    ///
    /// Only execute bits are propagated, so files stay owner-writable remotely (re-uploads
    /// keep working) while executables remain executable.
    #[cfg(unix)]
    fn upload_mode(file: &fs::File) -> io::Result<i32> {
        use std::os::unix::fs::PermissionsExt;
        let exec_bits = file.metadata()?.permissions().mode() & 0o111;
        Ok(0o644 | (exec_bits as i32))
    }

    /// Mode for a newly created remote file on platforms without Unix permission bits.
    #[cfg(not(unix))]
    fn upload_mode(_file: &fs::File) -> io::Result<i32> {
        Ok(0o644)
    }
}