context: refactor context command running
All checks were successful
CI / build (push) Successful in 1m55s

This commit is contained in:
2025-12-17 10:05:47 +01:00
parent 35eaaaf93a
commit 106c61a096
4 changed files with 88 additions and 139 deletions

View File

@@ -1,9 +1,8 @@
use super::api::{CommandRunner, ContextDriver};
/// SSH context: execute commands over an SSH connection
/// Context driver: Copies over SFTP with ssh2, executes commands over ssh2 channels
use super::api::ContextDriver;
use log::debug;
use std::fs;
/// SSH context: execute commands over an SSH connection
/// Context driver: Copies over SFTP with ssh2
/// Command runner: Executes commands over an ssh2 channel (with PTY)
use std::io::{self, Read};
use std::net::TcpStream;
#[cfg(unix)]
@@ -67,16 +66,6 @@ impl ContextDriver for SshDriver {
Ok(remote_dest)
}
fn create_runner(&self, program: String) -> Box<dyn CommandRunner> {
Box::new(SshRunner {
host: self.host.clone(),
user: self.user.clone(),
port: self.port,
program,
args: Vec::new(),
})
}
fn retrieve_path(&self, src: &Path, dest: &Path) -> io::Result<()> {
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
let sftp = sess.sftp().map_err(io::Error::other)?;
@@ -101,6 +90,71 @@ impl ContextDriver for SshDriver {
Ok(files)
}
fn run(&self, program: &str, args: &[String]) -> io::Result<std::process::ExitStatus> {
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
let mut channel = sess.channel_session().map_err(io::Error::other)?;
// Construct command line
let mut cmd_line = program.to_string();
for arg in args {
cmd_line.push(' ');
cmd_line.push_str(arg); // TODO: escape
}
debug!("Executing SSH command: {}", cmd_line);
channel
.request_pty("xterm", None, None)
.map_err(|e| io::Error::other(format!("Failed to request PTY: {}", e)))?;
channel.exec(&cmd_line).map_err(io::Error::other)?;
let mut stdout_stream = channel.stream(0);
let mut stdout = io::stdout();
io::copy(&mut stdout_stream, &mut stdout)?;
channel.wait_close().map_err(io::Error::other)?;
let code = channel.exit_status().unwrap_or(-1);
Ok(ExitStatus::from_raw(code))
}
fn run_output(&self, program: &str, args: &[String]) -> io::Result<std::process::Output> {
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
let mut channel = sess.channel_session().map_err(io::Error::other)?;
// Construct command line
let mut cmd_line = program.to_string();
for arg in args {
cmd_line.push(' ');
cmd_line.push_str(arg); // TODO: escape
}
channel.exec(&cmd_line).map_err(io::Error::other)?;
let mut stdout = Vec::new();
channel.read_to_end(&mut stdout)?;
let mut stderr = Vec::new();
channel.stderr().read_to_end(&mut stderr)?;
channel.wait_close().map_err(io::Error::other)?;
#[cfg(unix)]
let status = std::process::ExitStatus::from_raw(channel.exit_status().unwrap_or(-1));
#[cfg(not(unix))]
let status = {
panic!("SSH output capture only supported on Unix");
};
Ok(std::process::Output {
status,
stdout,
stderr,
})
}
fn prepare_work_dir(&self) -> io::Result<String> {
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
let mut channel = sess.channel_session().map_err(io::Error::other)?;
@@ -167,81 +221,3 @@ impl SshDriver {
Ok(())
}
}
pub struct SshRunner {
pub host: String,
pub user: Option<String>,
pub port: Option<u16>,
pub program: String,
pub args: Vec<String>,
}
impl SshRunner {
fn prepare_channel(&self) -> io::Result<(ssh2::Channel, String)> {
let sess = connect_ssh(&self.host, self.user.as_deref(), self.port)?;
let channel = sess.channel_session().map_err(io::Error::other)?;
// Construct command line
let mut cmd_line = self.program.clone();
for arg in &self.args {
cmd_line.push(' ');
cmd_line.push_str(arg); // TODO: escape
}
Ok((channel, cmd_line))
}
}
impl CommandRunner for SshRunner {
fn add_arg(&mut self, arg: String) {
self.args.push(arg);
}
fn status(&mut self) -> io::Result<ExitStatus> {
let (mut channel, cmd_line) = self.prepare_channel()?;
debug!("Executing SSH command: {}", cmd_line);
channel
.request_pty("xterm", None, None)
.map_err(|e| io::Error::other(format!("Failed to request PTY: {}", e)))?;
channel.exec(&cmd_line).map_err(io::Error::other)?;
let mut stdout_stream = channel.stream(0);
let mut stdout = io::stdout();
io::copy(&mut stdout_stream, &mut stdout)?;
channel.wait_close().map_err(io::Error::other)?;
let code = channel.exit_status().unwrap_or(-1);
Ok(ExitStatus::from_raw(code))
}
fn output(&mut self) -> io::Result<std::process::Output> {
let (mut channel, cmd_line) = self.prepare_channel()?;
channel.exec(&cmd_line).map_err(io::Error::other)?;
let mut stdout = Vec::new();
channel.read_to_end(&mut stdout)?;
let mut stderr = Vec::new();
channel.stderr().read_to_end(&mut stderr)?;
channel.wait_close().map_err(io::Error::other)?;
#[cfg(unix)]
let status = std::process::ExitStatus::from_raw(channel.exit_status().unwrap_or(-1));
#[cfg(not(unix))]
let status = {
panic!("SSH output capture only supported on Unix");
};
Ok(std::process::Output {
status,
stdout,
stderr,
})
}
}