feat: basic functionality

This commit is contained in:
2026-06-09 19:09:57 +02:00
parent a3339f9d34
commit 04ed1f25ae
15 changed files with 5374 additions and 63 deletions
+144
View File
@@ -0,0 +1,144 @@
use crate::veprintln;
use anyhow::{anyhow, Context, Result};
use nix::unistd::{chroot, execve};
use std::collections::HashMap;
use std::path::Path;
/// Run a command (or an interactive shell) inside the chroot at `rootfs`.
///
/// Steps, in order:
/// 1. Capture the host's `TERM` and try to set the hostname to `chroot`
///    (a hostname failure only prints a warning).
/// 2. `chroot` into `rootfs` and move to `/` inside it.
/// 3. Pick the first shell that exists inside the chroot and build an
///    isolated environment with `setup_environment`.
/// 4. Change to the working directory: `/mnt/<basename>` of the first entry
///    of `bind_rw_paths` if that directory exists, otherwise `/root`,
///    otherwise `/`.
/// 5. Resolve the program to a concrete path and `execve` it with the
///    isolated environment.
///
/// `command` set to `None` or an empty vector starts the detected shell.
/// A program name containing `/` is executed as given; a bare name is looked
/// up in the chroot `PATH` (falling back to the working directory), because
/// `execve` itself performs no `PATH` search.
///
/// # Errors
///
/// Returns an error if the chroot or a directory change fails, if an
/// argument or environment entry contains a NUL byte, if the program cannot
/// be found, or if `execve` fails. On success this function does not return:
/// the process image is replaced.
pub fn run_chroot(
    rootfs: &Path,
    command: Option<Vec<String>>,
    bind_rw_paths: &[std::path::PathBuf],
) -> Result<()> {
    // Get TERM from host before chroot
    let host_term = std::env::var("TERM").unwrap_or_else(|_| "xterm-256color".to_string());

    // Set hostname in UTS namespace
    if let Err(e) = crate::namespace::set_hostname("chroot") {
        eprintln!("Warning: Failed to set hostname: {}", e);
    }

    chroot(rootfs).context("Failed to chroot")?;
    // chroot(2) does not change the working directory, so "." would still
    // point outside the new root until we move. Do it immediately.
    std::env::set_current_dir("/").context("Failed to change to chroot root directory")?;

    // Now we're inside the chroot - determine the shell from the chroot
    // filesystem, not the host's.
    let shell = ["/bin/bash", "/bin/sh", "/usr/bin/bash", "/usr/bin/sh"]
        .iter()
        .copied()
        .find(|candidate| Path::new(candidate).exists())
        .unwrap_or("/bin/sh"); // Will fail with clear error if not present

    // Set up environment variables (after chroot, so paths are correct)
    let env = setup_environment(shell, &host_term);

    // Determine the command to run
    let (program, args) = match command {
        Some(cmd) if !cmd.is_empty() => {
            let program = cmd[0].clone();
            let args = cmd
                .iter()
                .map(|s| {
                    std::ffi::CString::new(s.as_str())
                        .with_context(|| format!("Argument contains a null byte: {:?}", s))
                })
                .collect::<Result<Vec<_>>>()?;
            (program, args)
        }
        _ => {
            // Run shell (already determined above based on chroot filesystem)
            let program = shell.to_string();
            let args =
                vec![std::ffi::CString::new(shell).context("Shell path contains a null byte")?];
            (program, args)
        }
    };

    // Build an explicit envp from setup_environment so the host environment
    // is never inherited. execve takes this array directly; the host process
    // environment is not touched at all.
    let env_cstrings = env
        .iter()
        .map(|(k, v)| {
            std::ffi::CString::new(format!("{}={}", k, v))
                .with_context(|| format!("Environment variable contains a null byte: {}={}", k, v))
        })
        .collect::<Result<Vec<_>>>()?;

    // Change to first bind_rw directory if available, otherwise /root, otherwise /
    // bind_rw paths are mounted at /mnt/<basename> (see mount.rs setup_bind_rw)
    let working_dir = bind_rw_paths
        .first()
        .map(|first| Path::new("/mnt").join(first.file_name().unwrap_or_default()))
        .filter(|dest_dir| dest_dir.exists())
        .or_else(|| {
            Path::new("/root")
                .exists()
                .then(|| std::path::PathBuf::from("/root"))
        })
        .unwrap_or_else(|| std::path::PathBuf::from("/"));
    std::env::set_current_dir(&working_dir).context("Failed to change to working directory")?;

    // Print welcome message
    veprintln!("Entering chroot at {}", rootfs.display());
    for path in bind_rw_paths {
        let basename = path
            .file_name()
            .map(|n| n.to_string_lossy())
            .unwrap_or_default();
        veprintln!("Read-write mount: /mnt/{}", basename);
    }
    veprintln!("Working directory: {}", working_dir.display());

    // Resolve the program to the path that will actually be executed.
    // execve(2) does not search PATH, so a bare name such as "ls" must be
    // turned into e.g. "/usr/bin/ls" here; argv[0] keeps the name as typed.
    let exec_path = if program.contains('/') {
        let candidate = std::path::PathBuf::from(&program);
        if !candidate.exists() {
            return Err(anyhow!("Program not found: {}", program));
        }
        candidate
    } else {
        env.get("PATH")
            .and_then(|path| {
                path.split(':')
                    .map(|dir| Path::new(dir).join(&program))
                    // is_file skips directories that merely share the name.
                    .find(|candidate| candidate.is_file())
            })
            // Fall back to a file of that name in the working directory.
            .or_else(|| {
                Path::new(&program)
                    .exists()
                    .then(|| std::path::PathBuf::from(&program))
            })
            .ok_or_else(|| anyhow!("Program not found: {}", program))?
    };

    // Exec the program directly with an explicit, isolated environment.
    // execve never returns on success.
    let program_cstr = std::ffi::CString::new(std::os::unix::ffi::OsStringExt::into_vec(
        exec_path.into_os_string(),
    ))
    .context("Invalid program name")?;
    let result = execve(&program_cstr, &args, &env_cstrings);
    match result {
        Ok(_) => Ok(()), // Never reached
        Err(e) => Err(anyhow!("Failed to exec {}: {}", program, e)),
    }
}
/// Build the fixed set of environment variables for the chrooted process.
///
/// The result always holds exactly five entries: `HOME=/root`, `USER=root`,
/// `SHELL` (taken from `shell`), `TERM` (taken from `term`) and a default
/// `PATH` covering the usual `sbin`/`bin` directories.
/// Call it after `chroot` so `shell` refers to a path inside the chroot.
fn setup_environment(shell: &str, term: &str) -> HashMap<&'static str, String> {
    let default_path = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin";
    let entries: Vec<(&'static str, String)> = vec![
        ("HOME", String::from("/root")),
        ("USER", String::from("root")),
        ("SHELL", String::from(shell)),
        ("TERM", String::from(term)),
        ("PATH", String::from(default_path)),
    ];
    entries.into_iter().collect()
}
+43
View File
@@ -0,0 +1,43 @@
use clap::Parser;
use std::path::PathBuf;
// Command-line interface definition, parsed by clap's derive API.
// NOTE: with clap's derive API the `///` comments on the struct and its
// fields are used as help text, so they are user-visible output; change them
// only when the help output itself should change.
/// Enter chroot environments with Linux namespaces
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None, override_usage = "ecr [OPTIONS] <DISTRO[:VERSION]> -- [COMMAND]...")]
pub struct Args {
// Required positional argument; an optional version may follow a ':'.
/// Distribution name (e.g., ubuntu, debian, arch, alpine, fedora)
#[arg(value_name = "DISTRO[:VERSION]")]
pub distro: String,
// `-a` / `--arch`; `None` when the flag is not given.
/// Target architecture
#[arg(short, long, value_name = "ARCH")]
pub arch: Option<String>,
// `--bind PATH`; a `Vec` so every occurrence of the flag is collected.
/// Directory to overlay-mount (can be specified multiple times, default: current directory)
#[arg(long, value_name = "PATH")]
pub bind: Vec<PathBuf>,
// `--bind-rw PATH`; a `Vec` so every occurrence of the flag is collected.
/// Directory to bind-mount read-write at /mnt/<basename> (overrides regular bind, can be specified multiple times)
#[arg(long, value_name = "PATH")]
pub bind_rw: Vec<PathBuf>,
// `--no-cache` boolean switch.
/// Download fresh tarball, ignore cache
#[arg(long)]
pub no_cache: bool,
// `--no-bind` boolean switch.
/// Skip mounting any directory
#[arg(long)]
pub no_bind: bool,
// `-v` / `--verbose` boolean switch.
/// Print diagnostic messages (URLs, manifest info, extraction steps, etc.)
#[arg(short = 'v', long)]
pub verbose: bool,
// Trailing positional values. `allow_hyphen_values` lets arguments that
// start with '-' be taken as part of the command instead of as options.
/// Command to run inside the chroot (default: interactive shell)
#[arg(
trailing_var_arg = true,
allow_hyphen_values = true,
value_name = "COMMAND"
)]
pub command: Vec<String>,
}
+41
View File
@@ -0,0 +1,41 @@
use anyhow::{Context, Result};
use serde::Deserialize;
/// User configuration, deserialized from `ecr.yaml` by `Config::load`.
#[derive(Debug, Deserialize, Default)]
pub struct Config {
/// DNS server addresses as strings. A missing `dns` key deserializes to an
/// empty list (`#[serde(default)]`); `Config::load` replaces an empty list
/// with `["1.1.1.1"]`.
#[serde(default)]
pub dns: Vec<String>,
}
impl Config {
    /// Load the configuration from `ecr.yaml` in the user's config directory.
    ///
    /// When the config directory cannot be determined or the file does not
    /// exist, the defaults are returned. When the file exists but lists no
    /// DNS servers, `dns` falls back to `["1.1.1.1"]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file exists but cannot be read or parsed.
    pub fn load() -> Result<Self> {
        let fallback_dns = || vec!["1.1.1.1".to_string()];

        // No usable config file: hand back the built-in defaults.
        let path = match dirs::config_dir() {
            Some(dir) => dir.join("ecr.yaml"),
            None => {
                return Ok(Config {
                    dns: fallback_dns(),
                })
            }
        };
        if !path.exists() {
            return Ok(Config {
                dns: fallback_dns(),
            });
        }

        let raw = std::fs::read_to_string(&path).context("Failed to read config file")?;
        let mut parsed: Config =
            serde_yaml::from_str(&raw).context("Failed to parse config file")?;

        // An empty (or absent) `dns` list means "use the default resolver".
        if parsed.dns.is_empty() {
            parsed.dns = fallback_dns();
        }
        Ok(parsed)
    }
}
+520
View File
@@ -0,0 +1,520 @@
use anyhow::{anyhow, Context, Result};
/// Known distributions with optimized direct tarball downloads.
///
/// `parse_image_ref` maps the names `ubuntu` and `alpine` to these variants;
/// other names are resolved as OCI image references instead.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Distro {
/// Ubuntu; `resolve_ubuntu_url` builds a cdimage.ubuntu.com `ubuntu-base` tarball URL.
Ubuntu,
/// Alpine; `resolve_alpine_url` builds a dl-cdn.alpinelinux.org minirootfs tarball URL.
Alpine,
}
/// Represents the source of the container image, as produced by
/// `parse_image_ref`.
#[derive(Debug, Clone, PartialEq)]
pub enum ImageSource {
/// Direct tarball download from known distro
DirectTarball {
/// Which known distro to download.
distro: Distro,
/// Version as typed by the user (codename, number or alias); `None`
/// when no `:VERSION` suffix was given, which the resolvers treat as
/// `"latest"`.
version: Option<String>,
},
/// OCI/Docker registry image
OciImage {
/// Registry host, optionally with a port; `"docker.io"` for Docker Hub.
registry: String,
/// Repository path within the registry (e.g. `library/ubuntu`).
repository: String,
/// Image tag; `"latest"` when the reference carried none.
tag: String,
/// Architecture name after mapping through `map_oci_arch`.
architecture: String,
},
}
impl Distro {
    /// Look up a known distro by name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns an error for any name other than `ubuntu` or `alpine`.
    pub fn from_name(name: &str) -> Result<Self> {
        let lowered = name.to_lowercase();
        if lowered == "ubuntu" {
            return Ok(Distro::Ubuntu);
        }
        if lowered == "alpine" {
            return Ok(Distro::Alpine);
        }
        Err(anyhow!("Unknown distribution: {}", name))
    }
}
/// Turn a user-supplied image reference into an [`ImageSource`].
///
/// Resolution order:
/// - `docker://…` or `oci://…`: the rest is parsed as an OCI reference.
/// - Anything containing `/` (e.g. `quay.io/centos/centos:stream9`): OCI
///   reference.
/// - `ubuntu[:VERSION]` / `alpine[:VERSION]` (case-insensitive): direct
///   tarball download.
/// - `arch[:TAG]` → `docker.io/library/archlinux`, `gentoo[:TAG]` →
///   `docker.io/gentoo/stage3`; the tag defaults to `latest`.
/// - Any other name (debian, fedora, …): parsed as an OCI reference, which
///   places bare names on Docker Hub.
///
/// `arch` is mapped to its OCI name for the OCI variants.
pub fn parse_image_ref(input: &str, arch: &str) -> Result<ImageSource> {
    let reference = input.trim();

    // An explicit scheme always means "OCI", whatever follows.
    let explicit = reference
        .strip_prefix("docker://")
        .or_else(|| reference.strip_prefix("oci://"));
    if let Some(rest) = explicit {
        return parse_oci_ref(rest, arch);
    }

    // A '/' can only appear in registry- or org-qualified references.
    if reference.contains('/') {
        return parse_oci_ref(reference, arch);
    }

    // Split "name[:version]" at the first ':'.
    let mut pieces = reference.splitn(2, ':');
    let name = pieces.next().unwrap_or(reference);
    let version = pieces.next().map(str::to_string);

    // Distros that are pinned to a specific Docker Hub repository.
    let docker_hub = |repository: &str, tag: Option<String>| ImageSource::OciImage {
        registry: "docker.io".to_string(),
        repository: repository.to_string(),
        tag: tag.unwrap_or_else(|| "latest".to_string()),
        architecture: map_oci_arch(arch),
    };

    let lowered = name.to_lowercase();
    match lowered.as_str() {
        // Known distros with an optimized tarball path.
        "ubuntu" | "alpine" => Ok(ImageSource::DirectTarball {
            distro: Distro::from_name(name)?,
            version,
        }),
        // Arch: use Docker image (has mirrors configured, unlike bootstrap tarball)
        "arch" => Ok(docker_hub("library/archlinux", version)),
        // Gentoo: gentoo/stage3 on Docker Hub is the full rootfs.
        "gentoo" => Ok(docker_hub("gentoo/stage3", version)),
        // Everything else goes through the generic OCI parser.
        _ => parse_oci_ref(reference, arch),
    }
}
/// Parse an OCI image reference of the form `[registry/]repository[:tag]`.
///
/// - The tag is whatever follows the first `:` in the last path component;
///   a `:` before the last `/` belongs to a `host:port` registry. Without a
///   tag, `latest` is used.
/// - The first path component is taken as the registry when it contains `.`
///   or `:` or equals `localhost`; otherwise the reference is a Docker Hub
///   `org/image`.
/// - A reference without any `/` becomes `docker.io/library/<name>`.
fn parse_oci_ref(input: &str, arch: &str) -> Result<ImageSource> {
    // Only look for the tag separator inside the last path component.
    let last_component_start = input.rfind('/').map_or(0, |slash| slash + 1);
    let (without_tag, tag) = match input[last_component_start..].find(':') {
        Some(offset) => {
            let colon = last_component_start + offset;
            (&input[..colon], input[colon + 1..].to_string())
        }
        None => (input, "latest".to_string()),
    };

    let (registry, repository) = match without_tag.split_once('/') {
        // Hostname, host:port, or the literal "localhost".
        Some((host, path))
            if host.contains('.') || host.contains(':') || host == "localhost" =>
        {
            (host.to_string(), path.to_string())
        }
        // Org-qualified Docker Hub shorthand: "myorg/myimage"
        Some(_) => ("docker.io".to_string(), without_tag.to_string()),
        // Bare image name → Docker Hub library image
        None => ("docker.io".to_string(), format!("library/{}", without_tag)),
    };

    Ok(ImageSource::OciImage {
        registry,
        repository,
        tag,
        architecture: map_oci_arch(arch),
    })
}
/// Translate an architecture name into the name used in OCI manifests.
///
/// Aliases are folded onto the OCI spelling (`x86_64` → `amd64`,
/// `aarch64` → `arm64`, `armhf`/`armv7` → `arm`, `ppc64el` → `ppc64le`);
/// every other name, including those already in OCI form, is returned
/// unchanged.
pub fn map_oci_arch(arch: &str) -> String {
    let canonical = match arch {
        "x86_64" => "amd64",
        "aarch64" => "arm64",
        "armhf" | "armv7" => "arm",
        "ppc64el" => "ppc64le",
        already_canonical_or_unknown => already_canonical_or_unknown,
    };
    canonical.to_string()
}
/// Translate an ecr architecture name into the spelling a distro uses in its
/// download URLs.
///
/// Ubuntu uses the ecr names as-is. Alpine renames `amd64` → `x86_64`,
/// `arm64` → `aarch64`, `armhf` → `armv7` and `ppc64el` → `ppc64le`.
/// Names without a mapping are returned unchanged.
pub fn map_arch(distro: Distro, arch: &str) -> String {
    let mapped = match (distro, arch) {
        (Distro::Alpine, "amd64") => "x86_64",
        (Distro::Alpine, "arm64") => "aarch64",
        (Distro::Alpine, "armhf") => "armv7",
        (Distro::Alpine, "ppc64el") => "ppc64le",
        // Ubuntu names and all remaining Alpine names are identical to ecr's.
        (Distro::Ubuntu, same) | (Distro::Alpine, same) => same,
    };
    mapped.to_string()
}
/// Resolve the download URL for a known distro (optimized path)
pub fn resolve_distro_url(distro: &Distro, version: Option<&str>, arch: &str) -> Result<String> {
match distro {
Distro::Ubuntu => resolve_ubuntu_url(version, arch),
Distro::Alpine => resolve_alpine_url(version, arch),
}
}
/// Fetch the latest Ubuntu codename from a changelogs.ubuntu.com meta-release file.
/// Pass the LTS-only URL to get the latest LTS, or the full URL for the latest release.
///
/// The file is scanned top to bottom: each `Dist: <codename>` line opens an
/// entry, and a following `Supported: 1` line makes that entry the current
/// answer. The last such entry in the file wins.
///
/// # Errors
///
/// Returns an error if the request fails, the server answers with an HTTP
/// error status, the body cannot be read, or no supported entry is found.
fn fetch_ubuntu_codename(meta_release_url: &str) -> Result<String> {
    let text = reqwest::blocking::get(meta_release_url)
        // Reject 4xx/5xx responses instead of scanning an error page.
        .and_then(|response| response.error_for_status())
        .with_context(|| format!("Failed to fetch {}", meta_release_url))?
        .text()
        .with_context(|| format!("Failed to read {}", meta_release_url))?;

    let mut current_dist: Option<String> = None;
    let mut latest: Option<String> = None;
    for line in text.lines() {
        if let Some(dist) = line.strip_prefix("Dist: ") {
            current_dist = Some(dist.trim().to_string());
        } else if line.trim_start().starts_with("Supported: 1") {
            // `take` so one `Dist:` entry is counted at most once.
            if let Some(dist) = current_dist.take() {
                latest = Some(dist);
            }
        }
    }

    latest.ok_or_else(|| {
        anyhow!(
            "Could not determine Ubuntu codename from {}",
            meta_release_url
        )
    })
}
/// Fetch the latest minirootfs version for a given Alpine CDN branch.
///
/// `branch` is the directory name on the CDN, e.g. `"latest-stable"`,
/// `"edge"` or `"v3.23"`; `arch` is the Alpine architecture name used in the
/// URL. The branch's `latest-releases.yaml` lists the current release files;
/// the first entry whose `file` name contains `minirootfs` supplies the exact
/// version needed for the tarball filename.
///
/// # Errors
///
/// Returns an error if the request fails, the server answers with an HTTP
/// error status (for example an unknown branch), the YAML cannot be parsed,
/// or it contains no usable minirootfs entry.
fn fetch_alpine_version_from_branch(branch: &str, arch: &str) -> Result<String> {
    #[derive(serde::Deserialize)]
    struct AlpineRelease {
        file: Option<String>,
        version: Option<String>,
    }

    let url = format!(
        "https://dl-cdn.alpinelinux.org/alpine/{}/releases/{}/latest-releases.yaml",
        branch, arch
    );
    let text = reqwest::blocking::get(&url)
        // Reject 4xx/5xx responses so an error page is not fed to the YAML
        // parser and reported as a parse failure.
        .and_then(|response| response.error_for_status())
        .with_context(|| {
            format!(
                "Failed to fetch Alpine latest-releases.yaml for branch {}",
                branch
            )
        })?
        .text()
        .context("Failed to read Alpine latest-releases.yaml")?;
    let releases: Vec<AlpineRelease> =
        serde_yaml::from_str(&text).context("Failed to parse Alpine latest-releases.yaml")?;

    for release in releases {
        let file = match release.file.as_deref() {
            Some(name) if name.contains("minirootfs") => name,
            _ => continue,
        };
        // Prefer explicit `version:` field; fall back to parsing the filename.
        // Filename format: alpine-minirootfs-3.23.0-x86_64.tar.gz
        if let Some(version) = release.version {
            return Ok(version);
        }
        if let Some(version) = file.split('-').nth(2) {
            return Ok(version.to_string());
        }
    }

    Err(anyhow!(
        "Could not find minirootfs entry in Alpine latest-releases.yaml for branch {}",
        branch
    ))
}
/// Fetch the current minirootfs version of Alpine's `latest-stable` branch.
fn fetch_alpine_latest_version(arch: &str) -> Result<String> {
    const BRANCH: &str = "latest-stable";
    fetch_alpine_version_from_branch(BRANCH, arch)
}
/// Resolve a `major.minor` Alpine series (e.g. `"3.24"`) to its current
/// patch release by querying the CDN branch named `v<series>`.
fn fetch_alpine_minor_version(minor: &str, arch: &str) -> Result<String> {
    let branch = format!("v{}", minor);
    fetch_alpine_version_from_branch(&branch, arch)
}
/// Resolve the canonical version string for a distro.
///
/// Floating aliases (e.g. "latest", "lts") are resolved by the
/// distro-specific resolvers, which may perform a network lookup. The
/// returned string is suitable for use as a stable cache key.
pub fn resolve_distro_version(
    distro: &Distro,
    version: Option<&str>,
    arch: &str,
) -> Result<String> {
    let canonical = match distro {
        Distro::Alpine => resolve_alpine_version(version, arch)?,
        // The Ubuntu resolver does not take an architecture.
        Distro::Ubuntu => resolve_ubuntu_version(version)?,
    };
    Ok(canonical)
}
/// Build a map of YY.MM version strings to codenames by parsing the Ubuntu
/// meta-release file (e.g. "24.04" → "noble", "22.04" → "jammy").
/// Uses the full meta-release (not -lts) so non-LTS versions are also covered.
///
/// # Errors
///
/// Returns an error if the request fails, the server answers with an HTTP
/// error status, or the body cannot be read.
fn fetch_ubuntu_version_map() -> Result<std::collections::HashMap<String, String>> {
    let text = reqwest::blocking::get("https://changelogs.ubuntu.com/meta-release")
        // Reject 4xx/5xx responses; otherwise an error page would silently
        // produce an empty map and a misleading "unknown version" error.
        .and_then(|response| response.error_for_status())
        .context("Failed to fetch Ubuntu meta-release")?
        .text()
        .context("Failed to read Ubuntu meta-release")?;

    let mut map = std::collections::HashMap::new();
    let mut current_dist: Option<String> = None;
    for line in text.lines() {
        if let Some(dist) = line.strip_prefix("Dist: ") {
            current_dist = Some(dist.trim().to_string());
        } else if let Some(version_str) = line.strip_prefix("Version: ") {
            // Version field may be "22.04", "22.04 LTS", or "24.04.1 LTS".
            // Normalise to YY.MM by taking the first two dot-separated components.
            let raw = version_str.split_whitespace().next().unwrap_or("");
            let mut parts = raw.splitn(3, '.');
            let normalised = match (parts.next(), parts.next()) {
                (Some(year), Some(month)) => format!("{}.{}", year, month),
                _ => raw.to_string(),
            };
            // A `Version:` line only counts once a `Dist:` line has been seen.
            if let Some(dist) = &current_dist {
                map.insert(normalised, dist.clone());
            }
        }
    }
    Ok(map)
}
/// Resolve a requested Ubuntu version to a codename.
///
/// - `None` / `"latest"`: newest supported release from the full meta-release.
/// - `"lts"` / `"latest-lts"`: newest supported release from the LTS file.
/// - `"YY.MM"`: looked up in the meta-release version map.
/// - Anything else is assumed to be a codename and returned unchanged.
fn resolve_ubuntu_version(version: Option<&str>) -> Result<String> {
    let requested = version.unwrap_or("latest");

    if requested == "latest" {
        return fetch_ubuntu_codename("https://changelogs.ubuntu.com/meta-release");
    }
    if requested == "lts" || requested == "latest-lts" {
        return fetch_ubuntu_codename("https://changelogs.ubuntu.com/meta-release-lts");
    }
    if !is_ubuntu_version_number(requested) {
        // Treat as a codename and pass through (e.g. "noble", "jammy")
        return Ok(requested.to_string());
    }

    // A YY.MM number is resolved via the meta-release file so the mapping
    // never goes stale.
    let codenames = fetch_ubuntu_version_map()?;
    match codenames.get(requested) {
        Some(codename) => Ok(codename.clone()),
        None => Err(anyhow!(
            "Unknown Ubuntu version '{}'. \
             Use the codename directly (e.g. noble, jammy) or check \
             https://changelogs.ubuntu.com/meta-release",
            requested
        )),
    }
}
/// Returns true when `s` consists of exactly two non-empty runs of ASCII
/// digits separated by a single dot, e.g. `"24.04"`.
fn is_ubuntu_version_number(s: &str) -> bool {
    fn all_digits(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit())
    }

    // Splitting at the first dot is enough: a second dot would end up in the
    // right-hand part and fail the digit check.
    match s.split_once('.') {
        Some((left, right)) => all_digits(left) && all_digits(right),
        None => false,
    }
}
/// Resolve a requested Alpine version to the exact version string used in
/// the tarball filename.
///
/// - `None` / `"latest"` / `"stable"`: current version of `latest-stable`.
/// - `"edge"`: current version string of the rolling `edge` branch
///   (e.g. "20250401"), fetched from the CDN.
/// - A string with exactly one dot (`"3.23"`): current patch release of
///   that series, fetched from the CDN.
/// - Anything else (`"3.23.0"`, an already-resolved date): returned as-is.
fn resolve_alpine_version(version: Option<&str>, arch: &str) -> Result<String> {
    let alpine_arch = map_arch(Distro::Alpine, arch);
    let requested = version.unwrap_or("latest");

    match requested {
        "latest" | "stable" => fetch_alpine_latest_version(&alpine_arch),
        "edge" => fetch_alpine_version_from_branch("edge", &alpine_arch),
        // major.minor series → ask the CDN for the current patch level
        series if series.matches('.').count() == 1 => {
            fetch_alpine_minor_version(series, &alpine_arch)
        }
        // full version already → pass through
        exact => Ok(exact.to_string()),
    }
}
/// Build the cdimage.ubuntu.com `ubuntu-base` daily tarball URL for the
/// requested version and architecture.
///
/// The version is first resolved to a codename, which appears both as a
/// directory and as the filename prefix.
fn resolve_ubuntu_url(version: Option<&str>, arch: &str) -> Result<String> {
    let codename = resolve_ubuntu_version(version)?;
    let ubuntu_arch = map_arch(Distro::Ubuntu, arch);
    let url = format!(
        "https://cdimage.ubuntu.com/ubuntu-base/{}/daily/current/{}-base-{}.tar.gz",
        codename, codename, ubuntu_arch
    );
    Ok(url)
}
/// Build the dl-cdn.alpinelinux.org minirootfs tarball URL for the requested
/// version and architecture.
///
/// `version` may be the original alias or a value that was already resolved
/// by `resolve_distro_version`. The CDN release directory is derived from it:
///   "latest"/"stable"          → "latest-stable"
///   "edge" or an all-digit date → "edge"   (e.g. "20250401")
///   "3.23" or "3.23.1"         → "v3.23"
fn resolve_alpine_url(version: Option<&str>, arch: &str) -> Result<String> {
    let requested = version.unwrap_or("latest");
    let alpine_arch = map_arch(Distro::Alpine, arch);
    let exact_version = resolve_alpine_version(Some(requested), arch)?;

    // A resolved edge version is a bare date stamp: digits only, no dots.
    let is_date_stamp = requested.chars().all(|c| c.is_ascii_digit());

    let release_dir = if requested == "latest" || requested == "stable" {
        "latest-stable".to_string()
    } else if requested == "edge" || is_date_stamp {
        "edge".to_string()
    } else {
        // Keep only major.minor for the directory name.
        let mut components = requested.split('.');
        let major = components.next().unwrap_or("0");
        let minor = components.next().unwrap_or("0");
        format!("v{}.{}", major, minor)
    };

    Ok(format!(
        "https://dl-cdn.alpinelinux.org/alpine/{}/releases/{}/alpine-minirootfs-{}-{}.tar.gz",
        release_dir, alpine_arch, exact_version, alpine_arch
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `OciImage` value; the architecture is always `amd64`
    /// because `parse` below always requests it.
    fn oci(registry: &str, repository: &str, tag: &str) -> ImageSource {
        ImageSource::OciImage {
            registry: String::from(registry),
            repository: String::from(repository),
            tag: String::from(tag),
            architecture: String::from("amd64"),
        }
    }

    /// Run `parse_oci_ref` for `amd64`, panicking on failure.
    fn parse(input: &str) -> ImageSource {
        parse_oci_ref(input, "amd64").expect("parse failed")
    }

    #[test]
    fn test_bare_name() {
        // A bare name lands in Docker Hub's library namespace with tag=latest.
        assert_eq!(parse("ubuntu"), oci("docker.io", "library/ubuntu", "latest"));
    }

    #[test]
    fn test_name_with_tag() {
        assert_eq!(
            parse("ubuntu:noble"),
            oci("docker.io", "library/ubuntu", "noble")
        );
    }

    #[test]
    fn test_org_repo() {
        assert_eq!(
            parse("myorg/myimage:v2"),
            oci("docker.io", "myorg/myimage", "v2")
        );
    }

    #[test]
    fn test_registry_with_port_no_tag() {
        // The port colon must not be mistaken for a tag separator.
        assert_eq!(
            parse("localhost:5000/myimage"),
            oci("localhost:5000", "myimage", "latest")
        );
    }

    #[test]
    fn test_registry_with_port_and_tag() {
        assert_eq!(
            parse("localhost:5000/myimage:v1"),
            oci("localhost:5000", "myimage", "v1")
        );
    }

    #[test]
    fn test_registry_with_port_and_org() {
        assert_eq!(
            parse("localhost:5000/org/myimage:v1"),
            oci("localhost:5000", "org/myimage", "v1")
        );
    }

    #[test]
    fn test_named_registry() {
        assert_eq!(
            parse("quay.io/centos/centos:stream9"),
            oci("quay.io", "centos/centos", "stream9")
        );
    }

    #[test]
    fn test_docker_hub_fqdn() {
        assert_eq!(
            parse("registry-1.docker.io/library/ubuntu:noble"),
            oci("registry-1.docker.io", "library/ubuntu", "noble")
        );
    }

    /// An already-resolved edge version is a date string like "20250401".
    /// The URL must still point into the "edge" CDN directory rather than
    /// a made-up "v20250401.0" directory.
    #[test]
    fn test_alpine_edge_resolved_date_uses_edge_directory() {
        // An all-digit version has no dots, so no network lookup happens.
        let url = resolve_alpine_url(Some("20250401"), "amd64").unwrap();
        let in_edge_dir = url.contains("/alpine/edge/");
        let has_dated_name = url.contains("minirootfs-20250401-");
        assert!(
            in_edge_dir,
            "expected URL to contain '/alpine/edge/' but got: {}",
            url
        );
        assert!(
            has_dated_name,
            "expected URL to contain 'minirootfs-20250401-' but got: {}",
            url
        );
    }
}
+703
View File
@@ -0,0 +1,703 @@
use crate::veprintln;
use anyhow::{anyhow, Context, Result};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use serde_json::Value;
use std::path::{Path, PathBuf};
use std::time::Duration;
use crate::distro::ImageSource;
/// Return the path of the digest sidecar file that belongs to a cached image.
///
/// The suffix `.digest` is appended to the full path, existing extensions
/// included: `foo.tar.gz` → `foo.tar.gz.digest`.
pub fn digest_sidecar(cache_path: &Path) -> PathBuf {
    let mut sidecar = cache_path.to_path_buf().into_os_string();
    sidecar.push(".digest");
    sidecar.into()
}
/// Blocking wrapper around `fetch_oci_digest_async`.
///
/// Returns the manifest digest the registry currently reports for `tag`
/// (the value of the `Docker-Content-Digest` response header) without
/// downloading any layers. A fresh Tokio runtime is created for the call.
pub fn fetch_oci_digest(registry: &str, repository: &str, tag: &str) -> Result<String> {
    let runtime = tokio::runtime::Runtime::new().context("Failed to create Tokio runtime")?;
    runtime.block_on(fetch_oci_digest_async(registry, repository, tag))
}
/// Request the manifest for `repository:tag` and return the value of the
/// `Docker-Content-Digest` response header.
///
/// `docker.io` is contacted through `registry-1.docker.io`. The request
/// carries a bearer token from `get_auth_token` and accepts both Docker and
/// OCI manifest / manifest-list media types.
///
/// # Errors
///
/// Returns an error if the client cannot be built, authentication or the
/// request fails, the status is not a success, or the header is missing or
/// not valid text.
async fn fetch_oci_digest_async(registry: &str, repository: &str, tag: &str) -> Result<String> {
    const ACCEPTED_MANIFEST_TYPES: &str = "application/vnd.docker.distribution.manifest.list.v2+json, application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.index.v1+json, application/vnd.oci.image.manifest.v1+json";

    let http = Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .context("Failed to create HTTP client")?;
    let token = get_auth_token(&http, registry, repository).await?;

    // Docker Hub serves its registry API from a different host.
    let api_host = if registry == "docker.io" {
        "registry-1.docker.io"
    } else {
        registry
    };
    let manifest_url = format!("https://{}/v2/{}/manifests/{}", api_host, repository, tag);

    let response = http
        .get(&manifest_url)
        .header("Accept", ACCEPTED_MANIFEST_TYPES)
        .header("Authorization", format!("Bearer {}", token))
        .send()
        .await
        .context("Failed to fetch manifest for digest check")?;

    let status = response.status();
    if !status.is_success() {
        return Err(anyhow!("Manifest fetch returned HTTP {}", status));
    }

    let digest_header = response
        .headers()
        .get("Docker-Content-Digest")
        .and_then(|value| value.to_str().ok());
    match digest_header {
        Some(digest) => Ok(digest.to_string()),
        None => Err(anyhow!(
            "Registry did not return a Docker-Content-Digest header"
        )),
    }
}
/// Download a container image to `dest`.
///
/// A `DirectTarball` source is resolved to an `http(s)` URL and fetched as a
/// single file; an `OciImage` source is pulled from its registry. `arch` is
/// only used to resolve the tarball URL; OCI sources carry their own
/// architecture.
///
/// One Tokio runtime is created here and drives whichever async download
/// is selected.
pub fn download_image(source: &ImageSource, dest: &Path, arch: &str) -> Result<()> {
    let runtime = tokio::runtime::Runtime::new().context("Failed to create Tokio runtime")?;

    let (distro, version) = match source {
        ImageSource::OciImage {
            registry,
            repository,
            tag,
            architecture,
        } => {
            veprintln!("Pulling OCI image: {}/{}:{}", registry, repository, tag);
            let pull = download_oci_image_async(registry, repository, tag, architecture, dest);
            return runtime.block_on(pull);
        }
        ImageSource::DirectTarball { distro, version } => (distro, version),
    };

    let url = crate::distro::resolve_distro_url(distro, version.as_deref(), arch)?;
    veprintln!("Resolved URL: {}", url);
    // Only plain HTTP(S) downloads are supported.
    let is_http = url.starts_with("http://") || url.starts_with("https://");
    if !is_http {
        return Err(anyhow!("Unsupported URL scheme: {}", url));
    }
    runtime.block_on(download_file_async(&url, dest))
}
/// Stream `url` into `dest`, showing a progress bar.
///
/// The body is written to a sibling temp file (`dest` with its extension
/// replaced by `partial`) and renamed onto `dest` once complete. If anything
/// fails after the request succeeded, the temp file is removed on a
/// best-effort basis so stale partial files do not pile up.
///
/// # Errors
///
/// Returns an error if the request fails or returns a non-success status,
/// the cache directory or temp file cannot be created, the stream is
/// interrupted, or the final rename fails.
async fn download_file_async(url: &str, dest: &Path) -> Result<()> {
    veprintln!("Downloading: {}", url);

    let http = Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
        .context("Failed to create HTTP client")?;
    let response = http
        .get(url)
        .send()
        .await
        .context("Failed to start download")?;

    let status = response.status();
    if !status.is_success() {
        return Err(anyhow!("Download failed: HTTP {}", status));
    }
    // 0 when the server does not announce a length.
    let expected_len = response.content_length().unwrap_or(0);

    // Make sure the directory that will hold the download exists.
    if let Some(cache_dir) = dest.parent() {
        tokio::fs::create_dir_all(cache_dir)
            .await
            .context("Failed to create cache directory")?;
    }

    let progress = ProgressBar::new(expected_len);
    let bar_style = ProgressStyle::default_bar()
        .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
        .unwrap()
        .progress_chars("#>-");
    progress.set_style(bar_style);

    let partial_path = dest.with_extension("partial");

    // Everything that can leave a partial file behind runs inside this
    // block so a single check afterwards can clean up.
    let outcome: Result<()> = async {
        let mut partial = tokio::fs::File::create(&partial_path)
            .await
            .context("Failed to create temporary file")?;
        let mut body = response.bytes_stream();
        loop {
            let chunk = match body.next().await {
                Some(next) => next.context("Download interrupted")?,
                None => break,
            };
            tokio::io::copy(&mut &chunk[..], &mut partial).await?;
            progress.inc(chunk.len() as u64);
        }
        progress.finish_and_clear();

        tokio::fs::rename(&partial_path, dest)
            .await
            .context("Failed to move partial download to final destination")?;
        veprintln!("Download complete: {}", dest.display());
        Ok(())
    }
    .await;

    if let Err(err) = outcome {
        // Best-effort cleanup; the file may not exist if creation failed.
        let _ = tokio::fs::remove_file(&partial_path).await;
        return Err(err);
    }
    Ok(())
}
/// Pull an OCI image into `dest` and record its manifest digest.
///
/// The download itself is done by `try_download_oci_image`. When it reports
/// a non-empty digest, that digest is written to the sidecar file next to
/// `dest` so a later run can detect whether the tag has moved. Writing the
/// sidecar is best-effort: a failure there is ignored.
async fn download_oci_image_async(
    registry: &str,
    repository: &str,
    tag: &str,
    arch: &str,
    dest: &Path,
) -> Result<()> {
    let http = Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
        .context("Failed to create HTTP client")?;

    let manifest_digest =
        try_download_oci_image(&http, registry, repository, tag, arch, dest).await?;

    // Nothing to persist when the registry sent no digest.
    if manifest_digest.is_empty() {
        return Ok(());
    }
    let _ = std::fs::write(digest_sidecar(dest), &manifest_digest);
    Ok(())
}
async fn try_download_oci_image(
client: &Client,
registry: &str,
repository: &str,
tag: &str,
arch: &str,
dest: &Path,
) -> Result<String> {
// Get authentication token for the registry
let token = get_auth_token(client, registry, repository).await?;
// Construct manifest URL based on registry
let manifest_url = if registry == "docker.io" {
// Docker Hub uses registry-1.docker.io for API
format!(
"https://registry-1.docker.io/v2/{}/manifests/{}",
repository, tag
)
} else {
format!("https://{}/v2/{}/manifests/{}", registry, repository, tag)
};
veprintln!("Fetching manifest from: {}", manifest_url);
// Request both manifest list and single manifest types (including OCI formats)
let response = client
.get(&manifest_url)
.header("Accept", "application/vnd.docker.distribution.manifest.list.v2+json, application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.index.v1+json, application/vnd.oci.image.manifest.v1+json")
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.context("Failed to get image manifest")?;
if !response.status().is_success() {
return Err(anyhow!(
"Failed to get manifest: HTTP {}",
response.status()
));
}
// Capture the tag-level digest before consuming the response body.
// This is the content-addressable identity of the manifest (or manifest
// list) at this tag — used to detect whether the tag has moved since the
// last download.
let tag_digest = response
.headers()
.get("Docker-Content-Digest")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_default();
let body = response.text().await?;
let manifest: Value = serde_json::from_str(&body).context("Failed to parse manifest JSON")?;
// Check if this is a manifest list (multi-arch)
let layers = if let Some(manifests) = manifest["manifests"].as_array() {
// This is a manifest list - find the right architecture
veprintln!("Got manifest list with {} manifests", manifests.len());
let manifest_entry = manifests
.iter()
.find(|m| {
let m_arch = m["platform"]["architecture"].as_str().unwrap_or("");
let m_os = m["platform"]["os"].as_str().unwrap_or("");
m_arch == arch && m_os == "linux"
})
.ok_or_else(|| {
// List available architectures in the error message
let available: Vec<&str> = manifests
.iter()
.filter_map(|m| m["platform"]["architecture"].as_str())
.collect();
anyhow!(
"No manifest found for architecture '{}'. Available: {}",
arch,
available.join(", ")
)
})?;
let manifest_digest = manifest_entry["digest"]
.as_str()
.ok_or_else(|| anyhow!("No digest in manifest entry"))?;
veprintln!("Found manifest digest: {}", manifest_digest);
// Now get the actual manifest for this architecture
let arch_manifest_url = if registry == "docker.io" {
format!(
"https://registry-1.docker.io/v2/{}/manifests/{}",
repository, manifest_digest
)
} else {
format!(
"https://{}/v2/{}/manifests/{}",
registry, repository, manifest_digest
)
};
let arch_response = client
.get(&arch_manifest_url)
.header("Accept", "application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.manifest.v1+json")
.header("Authorization", format!("Bearer {}", token))
.send()
.await
.context("Failed to get architecture manifest")?;
if !arch_response.status().is_success() {
return Err(anyhow!(
"Failed to get arch manifest: HTTP {}",
arch_response.status()
));
}
let arch_body = arch_response.text().await?;
let arch_manifest: Value = serde_json::from_str(&arch_body)
.context("Failed to parse architecture manifest JSON")?;
// Get all layers from architecture manifest (in order)
arch_manifest["layers"]
.as_array()
.ok_or_else(|| anyhow!("No layers in architecture manifest"))?
.clone()
} else {
// Single architecture manifest - get all layers in order
manifest["layers"]
.as_array()
.ok_or_else(|| anyhow!("No layers in manifest"))?
.clone()
};
veprintln!("Found {} layers to download", layers.len());
// Download all layer blobs into a temp directory
let temp_dir = tempfile::tempdir()?;
// layer_names[i] is the filename used both for the downloaded blob and in
// layers.manifest. The extension is derived from the layer's mediaType so
// extract_oci_layer can dispatch on the filename without relying on
// magic-byte detection (which would misfire on a zstd blob named .tar.gz).
let mut layer_names: Vec<String> = Vec::with_capacity(layers.len());
for (i, layer) in layers.iter().enumerate() {
let layer_digest = layer["digest"]
.as_str()
.ok_or_else(|| anyhow!("No digest in layer"))?;
let media_type = layer["mediaType"].as_str().unwrap_or("");
let ext = media_type_to_extension(media_type);
let layer_name = format!("layer_{}.{}", i, ext);
veprintln!(
"Fetching layer {}/{}: {}",
i + 1,
layers.len(),
layer_digest
);
let blob_url = if registry == "docker.io" {
format!(
"https://registry-1.docker.io/v2/{}/blobs/{}",
repository, layer_digest
)
} else {
format!(
"https://{}/v2/{}/blobs/{}",
registry, repository, layer_digest
)
};
let layer_path = temp_dir.path().join(&layer_name);
let label = format!("Layer {}/{} ({})", i + 1, layers.len(), layer_name);
download_blob_with_auth(client, &blob_url, &layer_path, &token, &label).await?;
layer_names.push(layer_name);
}
// Write the layer index so extract.rs knows the order
{
use std::io::Write;
let mut manifest_file = std::fs::File::create(temp_dir.path().join("layers.manifest"))?;
for name in &layer_names {
writeln!(manifest_file, "{}", name)?;
}
}
// Bundle layers.manifest + all layer blobs into a single .tar.gz cache file
// using the tar + flate2 crates — no external `tar` binary required.
let temp_dest = dest.with_extension("tar.partial");
let bundle_result: Result<()> = (|| {
use flate2::{write::GzEncoder, Compression};
let out_file =
std::fs::File::create(&temp_dest).context("Failed to create temporary bundle file")?;
let mut builder = tar::Builder::new(GzEncoder::new(out_file, Compression::default()));
builder
.append_path_with_name(temp_dir.path().join("layers.manifest"), "layers.manifest")
.context("Failed to add layers.manifest to bundle")?;
for name in &layer_names {
builder
.append_path_with_name(temp_dir.path().join(name), name)
.with_context(|| format!("Failed to add {} to bundle", name))?;
}
// into_inner finalises the tar end-of-archive marker and returns the
// GzEncoder; finish() flushes and closes the gzip stream.
builder
.into_inner()
.context("Failed to finalise tar archive")?
.finish()
.context("Failed to finalise gzip stream")?;
std::fs::rename(&temp_dest, dest).context("Failed to move bundle to cache destination")
})();
if bundle_result.is_err() {
// Best-effort cleanup; ignore errors (file may not exist if creation failed).
let _ = std::fs::remove_file(&temp_dest);
}
bundle_result?;
Ok(tag_digest)
}
/// Obtain a bearer token that grants pull access to `repository` on `registry`.
///
/// Docker Hub is special-cased with its fixed token endpoint. Every other
/// registry is probed with an unauthenticated `GET /v2/`:
///
/// * `200` – no authentication needed, an empty token is returned;
/// * `401` – the `WWW-Authenticate` challenge is parsed and its realm/service
///   are used to build the token URL;
/// * anything else – a warning is printed and an empty token is returned.
///
/// # Errors
/// Fails when the probe request cannot be sent, when a 401 response carries
/// no usable `WWW-Authenticate` header, or when the token request itself fails.
async fn get_auth_token(client: &Client, registry: &str, repository: &str) -> Result<String> {
    if registry == "docker.io" {
        let hub_url = format!(
            "https://auth.docker.io/token?service=registry.docker.io&scope=repository:{}:pull",
            repository
        );
        veprintln!("Getting auth token from: {}", hub_url);
        return fetch_bearer_token(client, &hub_url).await;
    }

    // Unauthenticated probe of the registry's API root.
    let ping_url = format!("https://{}/v2/", registry);
    let ping = client
        .get(&ping_url)
        .send()
        .await
        .with_context(|| format!("Failed to probe registry at {}", ping_url))?;

    let status_code = ping.status().as_u16();
    if status_code == 200 {
        // Open registry: callers treat an empty token as "send no header".
        return Ok(String::new());
    }
    if status_code != 401 {
        // Let the later manifest request surface the real problem.
        eprintln!(
            "Warning: registry probe returned HTTP {}; trying without auth",
            status_code
        );
        return Ok(String::new());
    }

    // 401: read the challenge that tells us where to fetch a token.
    let challenge = match ping
        .headers()
        .get("WWW-Authenticate")
        .and_then(|value| value.to_str().ok())
    {
        Some(text) => text,
        None => {
            return Err(anyhow!(
                "Registry {} returned 401 but no WWW-Authenticate header",
                registry
            ))
        }
    };
    let (realm, service) = match parse_www_authenticate(challenge) {
        Some(pair) => pair,
        None => {
            return Err(anyhow!(
                "Could not parse WWW-Authenticate header: {}",
                challenge
            ))
        }
    };

    // Any scope advertised by the probe is ignored; request pull on this repository.
    let token_url = format!(
        "{}?service={}&scope=repository:{}:pull",
        realm, service, repository
    );
    veprintln!("Getting auth token from: {}", token_url);
    fetch_bearer_token(client, &token_url).await
}
/// Parse a `Bearer realm="...",service="..."[,scope="..."]` header value.
///
/// Returns `(realm, service)` on success, or `None` when the challenge does
/// not use the Bearer scheme or lacks either parameter.
///
/// The scheme token and the parameter names are compared ASCII
/// case-insensitively (HTTP treats both as case-insensitive), and optional
/// whitespace around `=` is tolerated. Commas inside double-quoted values do
/// not split parameters. When a parameter is repeated, the last one wins.
fn parse_www_authenticate(header: &str) -> Option<(String, String)> {
    const SCHEME: &str = "Bearer";

    // `get` yields None for headers shorter than the scheme or when the cut
    // would fall inside a multi-byte character, so this never panics.
    let scheme = header.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    let params = header[SCHEME.len()..].strip_prefix(' ')?;

    // Split on commas that lie outside quoted strings.
    let mut parts: Vec<&str> = Vec::new();
    let mut start = 0;
    let mut in_quotes = false;
    for (idx, ch) in params.char_indices() {
        match ch {
            '"' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                parts.push(&params[start..idx]);
                start = idx + 1; // ',' is one byte, so this stays on a boundary
            }
            _ => {}
        }
    }
    parts.push(&params[start..]);

    let mut realm: Option<String> = None;
    let mut service: Option<String> = None;
    for part in parts {
        if let Some((key, value)) = part.split_once('=') {
            let key = key.trim();
            let value = value.trim().trim_matches('"');
            if key.eq_ignore_ascii_case("realm") {
                realm = Some(value.to_string());
            } else if key.eq_ignore_ascii_case("service") {
                service = Some(value.to_string());
            }
        }
    }
    Some((realm?, service?))
}
/// Request a bearer token from an already-assembled token URL.
///
/// A non-success HTTP status is not treated as an error: an empty token is
/// returned so the caller can attempt an anonymous pull. On success the JSON
/// body is searched for `token` first and `access_token` second.
///
/// # Errors
/// Fails when the request cannot be sent, the body cannot be read or parsed
/// as JSON, or neither token field is present.
async fn fetch_bearer_token(client: &Client, token_url: &str) -> Result<String> {
    let reply = client
        .get(token_url)
        .send()
        .await
        .context("Failed to request auth token")?;

    if !reply.status().is_success() {
        // Anonymous access may still work; signal "no token" with an empty string.
        return Ok(String::new());
    }

    let raw = reply.text().await?;
    let parsed: Value =
        serde_json::from_str(&raw).context("Failed to parse auth token response")?;

    // Accept either spelling of the token field.
    let found = parsed["token"]
        .as_str()
        .or_else(|| parsed["access_token"].as_str());
    match found {
        Some(token) => Ok(token.to_string()),
        None => Err(anyhow!("No token field in auth response: {}", raw)),
    }
}
/// Stream the blob at `url` into the file `dest`, showing a progress bar.
///
/// * `token` – bearer token; when empty no `Authorization` header is sent.
/// * `label` – text shown in front of the progress bar.
///
/// # Errors
/// Fails when the request cannot be started, the server answers with a
/// non-success status, the destination file cannot be created, or a chunk
/// cannot be read or written.
async fn download_blob_with_auth(
    client: &Client,
    url: &str,
    dest: &Path,
    token: &str,
    label: &str,
) -> Result<()> {
    use futures_util::StreamExt;
    use tokio::io::AsyncWriteExt;

    let request = if token.is_empty() {
        client.get(url)
    } else {
        client
            .get(url)
            .header("Authorization", format!("Bearer {}", token))
    };

    let response = request.send().await.context("Failed to start download")?;
    let status = response.status();
    if !status.is_success() {
        return Err(anyhow!("Failed to download blob: HTTP {}", status));
    }

    // A missing Content-Length is reported to the bar as a total of 0.
    let bar = ProgressBar::new(response.content_length().unwrap_or(0));
    let bar_style = ProgressStyle::default_bar()
        .template("{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
        .unwrap()
        .progress_chars("#>-");
    bar.set_style(bar_style);
    bar.set_message(label.to_string());

    let mut out = tokio::fs::File::create(dest)
        .await
        .context("Failed to create destination file")?;

    // Copy the body chunk by chunk, advancing the bar by each chunk's size.
    let mut body = response.bytes_stream();
    while let Some(piece) = body.next().await {
        let piece = piece.context("Failed to read chunk")?;
        out.write_all(&piece)
            .await
            .context("Failed to write chunk")?;
        bar.inc(piece.len() as u64);
    }

    out.flush().await?;
    bar.finish_and_clear();
    Ok(())
}
/// Translate a layer `mediaType` into the file extension used for the
/// downloaded blob.
///
/// The extension ends up in the layer's file name, which the extraction code
/// uses to pick a decompressor, so it has to match the blob's encoding.
///
/// Returns `"tar"` for uncompressed layers, `"tar.zst"` for zstd layers and
/// `"tar.gz"` for gzip layers. Any unrecognised (or empty) media type is
/// also mapped to `"tar.gz"`.
fn media_type_to_extension(media_type: &str) -> &'static str {
    match media_type {
        // Uncompressed OCI layers (distributable and non-distributable).
        "application/vnd.oci.image.layer.v1.tar"
        | "application/vnd.oci.image.layer.nondistributable.v1.tar" => "tar",
        // zstd-compressed OCI layers.
        "application/vnd.oci.image.layer.v1.tar+zstd"
        | "application/vnd.oci.image.layer.nondistributable.v1.tar+zstd" => "tar.zst",
        // gzip-compressed layers: Docker V2 schema 2 and OCI.
        "application/vnd.docker.image.rootfs.diff.tar.gzip"
        | "application/vnd.oci.image.layer.v1.tar+gzip"
        | "application/vnd.oci.image.layer.nondistributable.v1.tar+gzip" => "tar.gz",
        // Everything else is assumed to be gzip.
        _ => "tar.gz",
    }
}
/// Unit tests for the `WWW-Authenticate` parser and the media-type mapping.
#[cfg(test)]
mod tests {
    use super::parse_www_authenticate;

    #[test]
    fn test_docker_hub_challenge() {
        let hdr = r#"Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:library/ubuntu:pull""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://auth.docker.io/token");
        assert_eq!(service, "registry.docker.io");
    }

    #[test]
    fn test_quay_challenge() {
        let hdr = r#"Bearer realm="https://quay.io/v2/auth",service="quay.io",scope="repository:centos/centos:pull""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://quay.io/v2/auth");
        assert_eq!(service, "quay.io");
    }

    #[test]
    fn test_ghcr_challenge() {
        let hdr = r#"Bearer realm="https://ghcr.io/token",service="ghcr.io",scope="repository:owner/repo:pull""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://ghcr.io/token");
        assert_eq!(service, "ghcr.io");
    }

    #[test]
    fn test_gcr_challenge() {
        // No scope parameter at all: realm and service alone must be enough.
        let hdr = r#"Bearer realm="https://gcr.io/v2/token",service="gcr.io""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://gcr.io/v2/token");
        assert_eq!(service, "gcr.io");
    }

    #[test]
    fn test_value_with_comma_in_scope() {
        // The scope value carries a comma inside its quotes ("pull,push");
        // the splitter must not treat that comma as a parameter separator.
        let hdr = r#"Bearer realm="https://example.com/auth",scope="repository:a/b:pull,push",service="example.com""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://example.com/auth");
        assert_eq!(service, "example.com");
    }

    #[test]
    fn test_value_with_comma_in_service() {
        // A comma inside a *returned* value: a naive split on ',' would
        // truncate the service name, so this pins the quote-aware splitting.
        let hdr = r#"Bearer realm="https://example.com/auth",service="example.com,internal""#;
        let (realm, service) = parse_www_authenticate(hdr).unwrap();
        assert_eq!(realm, "https://example.com/auth");
        assert_eq!(service, "example.com,internal");
    }

    #[test]
    fn test_not_bearer_returns_none() {
        assert!(parse_www_authenticate("Basic realm=\"registry\"").is_none());
    }

    #[test]
    fn test_media_type_to_extension_known_types() {
        use super::media_type_to_extension;
        assert_eq!(
            media_type_to_extension("application/vnd.docker.image.rootfs.diff.tar.gzip"),
            "tar.gz"
        );
        assert_eq!(
            media_type_to_extension("application/vnd.oci.image.layer.v1.tar+gzip"),
            "tar.gz"
        );
        assert_eq!(
            media_type_to_extension("application/vnd.oci.image.layer.v1.tar+zstd"),
            "tar.zst"
        );
        assert_eq!(
            media_type_to_extension("application/vnd.oci.image.layer.v1.tar"),
            "tar"
        );
    }

    #[test]
    fn test_media_type_to_extension_unknown_falls_back_to_gz() {
        use super::media_type_to_extension;
        assert_eq!(media_type_to_extension(""), "tar.gz");
        assert_eq!(media_type_to_extension("text/plain"), "tar.gz");
    }
}
+312
View File
@@ -0,0 +1,312 @@
use crate::veprintln;
use anyhow::{Context, Result};
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
/// Unpack `tarball` into the directory `dest`.
///
/// Dispatch order:
/// 1. File names starting with `oci-` are handed to the multi-layer OCI
///    extractor without inspecting the archive contents.
/// 2. Otherwise the compression is chosen from the file-name suffix
///    (`.tar.gz`/`.tgz`, `.tar.xz`/`.txz`, `.tar.zst`/`.tar.zstd`, `.tar`).
/// 3. With no recognised suffix, the first six bytes are compared against the
///    gzip, xz and zstd magic numbers; anything else is treated as plain tar.
///
/// # Errors
/// Fails when the file cannot be opened, when fewer than six bytes can be
/// read during magic detection, or when the selected extractor fails.
pub fn extract_tarball(tarball: &Path, dest: &Path) -> Result<()> {
    let name = tarball.file_name().and_then(|n| n.to_str()).unwrap_or("");

    // OCI bundles are recognised purely by their file-name prefix.
    if name.starts_with("oci-") {
        veprintln!("Detected multi-layer OCI image, extracting layers...");
        return extract_multi_layer_oci(tarball, dest);
    }

    let primary = BufReader::new(
        File::open(tarball)
            .with_context(|| format!("Failed to open tarball: {}", tarball.display()))?,
    );

    // Suffix-based dispatch.
    let ends_with_any = |suffixes: &[&str]| suffixes.iter().any(|suffix| name.ends_with(*suffix));
    if ends_with_any(&[".tar.gz", ".tgz"]) {
        return extract_gz(primary, dest);
    }
    if ends_with_any(&[".tar.xz", ".txz"]) {
        return extract_xz(primary, dest);
    }
    if ends_with_any(&[".tar.zst", ".tar.zstd"]) {
        return extract_zst(primary, dest);
    }
    if name.ends_with(".tar") {
        return extract_tar(primary, dest);
    }

    // Unknown suffix: sniff the leading bytes through a separate handle, then
    // reopen the file so extraction starts at offset zero.
    let mut signature = [0u8; 6];
    {
        let mut sniffer = BufReader::new(File::open(tarball)?);
        sniffer.read_exact(&mut signature)?;
    }
    let reopened = BufReader::new(File::open(tarball)?);
    match signature {
        // gzip magic
        [0x1f, 0x8b, ..] => extract_gz(reopened, dest),
        // xz magic
        [0xfd, b'7', b'z', b'X', b'Z', 0x00] => extract_xz(reopened, dest),
        // zstd magic
        [0x28, 0xb5, 0x2f, 0xfd, ..] => extract_zst(reopened, dest),
        // Assume uncompressed tar
        _ => extract_tar(reopened, dest),
    }
}
/// Extract a multi-layer OCI image
fn extract_multi_layer_oci(tarball: &Path, dest: &Path) -> Result<()> {
let filename = tarball.file_name().and_then(|n| n.to_str()).unwrap_or("");
// Extract the outer OCI bundle (layers.manifest + layer_N.tar.gz files)
// into a temp directory. This is our own format, not an OCI layer, so
// plain unpack is fine here.
let temp_dir = tempfile::tempdir().context("Failed to create temp directory for OCI layers")?;
let file = File::open(tarball)
.with_context(|| format!("Failed to open OCI tarball: {}", tarball.display()))?;
let reader = BufReader::new(file);
if filename.ends_with(".tar.gz") || filename.ends_with(".tgz") {
tar::Archive::new(flate2::read::GzDecoder::new(reader))
.unpack(temp_dir.path())
.context("Failed to unpack OCI bundle")?;
} else if filename.ends_with(".tar.xz") || filename.ends_with(".txz") {
tar::Archive::new(xz2::read::XzDecoder::new(reader))
.unpack(temp_dir.path())
.context("Failed to unpack OCI bundle")?;
} else if filename.ends_with(".tar.zst") || filename.ends_with(".tar.zstd") {
tar::Archive::new(zstd::stream::read::Decoder::new(reader)?)
.unpack(temp_dir.path())
.context("Failed to unpack OCI bundle")?;
} else {
tar::Archive::new(reader)
.unpack(temp_dir.path())
.context("Failed to unpack OCI bundle")?;
}
// Read the layers manifest
let manifest = std::fs::read_to_string(temp_dir.path().join("layers.manifest"))
.context("Failed to read layers.manifest")?;
// Apply each layer in order with full whiteout handling.
for layer_name in manifest.lines() {
let layer_path = temp_dir.path().join(layer_name);
if !layer_path.exists() {
continue;
}
veprintln!("Extracting layer: {}", layer_name);
extract_oci_layer(&layer_path, dest)?;
}
Ok(())
}
/// Apply a single layer blob at `layer_path` to `dest`, honouring whiteout
/// markers.
///
/// The decompressor is picked from the file-name suffix when it is one of
/// `.tar.gz`/`.tgz`, `.tar.xz`/`.txz` or `.tar.zst`/`.tar.zstd`. For any other
/// name the first six bytes are compared against the gzip, xz and zstd magic
/// numbers, and the blob is treated as plain tar when none matches.
///
/// # Errors
/// Fails when the blob cannot be opened, the zstd decoder cannot be set up,
/// or applying the archive fails.
fn extract_oci_layer(layer_path: &Path, dest: &Path) -> Result<()> {
    use std::io::Read;

    let name = layer_path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("");
    let primary = BufReader::new(
        File::open(layer_path)
            .with_context(|| format!("Failed to open layer: {}", layer_path.display()))?,
    );

    // Suffix-based dispatch.
    let ends_with_any = |suffixes: &[&str]| suffixes.iter().any(|suffix| name.ends_with(*suffix));
    if ends_with_any(&[".tar.gz", ".tgz"]) {
        return extract_archive_with_whiteouts(
            tar::Archive::new(flate2::read::GzDecoder::new(primary)),
            dest,
        );
    }
    if ends_with_any(&[".tar.xz", ".txz"]) {
        return extract_archive_with_whiteouts(
            tar::Archive::new(xz2::read::XzDecoder::new(primary)),
            dest,
        );
    }
    if ends_with_any(&[".tar.zst", ".tar.zstd"]) {
        return extract_archive_with_whiteouts(
            tar::Archive::new(zstd::stream::read::Decoder::new(primary)?),
            dest,
        );
    }

    // Fall back to magic-byte detection. A short read simply leaves the
    // remaining signature bytes at zero, which matches none of the magics.
    let mut signature = [0u8; 6];
    {
        let mut sniffer = BufReader::new(File::open(layer_path)?);
        let _ = sniffer.read_exact(&mut signature);
    }

    // Reopen so the archive is read from the very first byte.
    let reopened = BufReader::new(File::open(layer_path)?);
    match signature {
        [0x1f, 0x8b, ..] => extract_archive_with_whiteouts(
            tar::Archive::new(flate2::read::GzDecoder::new(reopened)),
            dest,
        ),
        [0xfd, b'7', b'z', b'X', b'Z', 0x00] => extract_archive_with_whiteouts(
            tar::Archive::new(xz2::read::XzDecoder::new(reopened)),
            dest,
        ),
        [0x28, 0xb5, 0x2f, 0xfd, ..] => extract_archive_with_whiteouts(
            tar::Archive::new(zstd::stream::read::Decoder::new(reopened)?),
            dest,
        ),
        _ => extract_archive_with_whiteouts(tar::Archive::new(reopened), dest),
    }
}
/// Apply one OCI layer archive to `dest`, interpreting Docker whiteout markers:
///
/// - `.wh.<name>` — Delete `<name>` from a lower layer that was already
/// extracted into `dest`.
/// - `.wh..wh..opq` — Opaque whiteout: the directory that contains this entry
/// is new in this layer; delete everything already in that
/// directory from lower layers before applying new content.
///
/// All other entries are extracted normally via `Entry::unpack_in`.
fn extract_archive_with_whiteouts<R: std::io::Read>(
mut archive: tar::Archive<R>,
dest: &Path,
) -> Result<()> {
archive.set_preserve_permissions(true);
archive.set_preserve_ownerships(false);
archive.set_unpack_xattrs(false);
for entry in archive.entries().context("Failed to iterate tar entries")? {
let mut entry = entry.context("Failed to read tar entry")?;
// Clone the path before any mutable borrow of entry (needed for unpack_in)
let path = entry.path().context("Invalid tar entry path")?.into_owned();
let filename = path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_default();
if filename == ".wh..wh..opq" {
// Opaque whiteout: clear all previously-extracted content in the
// parent directory so only this layer's content is visible.
let parent = path.parent().unwrap_or(Path::new(""));
let dest_dir = dest.join(parent);
if dest_dir.symlink_metadata().is_ok() {
for child in std::fs::read_dir(&dest_dir)
.with_context(|| format!("Failed to read {}", dest_dir.display()))?
{
let child = child?;
let child_path = child.path();
remove_path(&child_path).with_context(|| {
format!("Opaque whiteout: failed to remove {}", child_path.display())
})?;
}
}
// Do not extract the .wh..wh..opq marker itself.
} else if let Some(real_name) = filename.strip_prefix(".wh.") {
// Regular whiteout: delete the named path from lower layers.
let parent = path.parent().unwrap_or(Path::new(""));
let target = dest.join(parent).join(real_name);
// symlink_metadata (lstat) does not follow symlinks, so a dangling
// symlink is correctly detected and removed rather than silently skipped.
if target.symlink_metadata().is_ok() {
remove_path(&target)
.with_context(|| format!("Whiteout: failed to remove {}", target.display()))?;
}
// Do not extract the .wh.* marker itself.
} else {
entry
.unpack_in(dest)
.with_context(|| format!("Failed to extract {}", path.display()))?;
}
}
Ok(())
}
/// Delete whatever lives at `path`.
///
/// Real directories are removed recursively; every other kind of entry
/// (regular file, symlink, including a symlink that points at a directory)
/// is unlinked with `remove_file`.
///
/// # Errors
/// Returns the underlying I/O error when `path` cannot be inspected (for
/// example because it does not exist) or cannot be removed.
fn remove_path(path: &Path) -> std::io::Result<()> {
    // lstat semantics: the entry itself is inspected, links are not followed,
    // so a symlink-to-dir is not reported as a directory here.
    let is_real_dir = std::fs::symlink_metadata(path)?.is_dir();
    if !is_real_dir {
        return std::fs::remove_file(path);
    }
    std::fs::remove_dir_all(path)
}
/// Unpack a gzip-compressed tar stream from `reader` into `dest`.
///
/// Permissions are preserved; ownership and extended attributes are not.
fn extract_gz<R: std::io::Read>(reader: R, dest: &Path) -> Result<()> {
    let mut tarball = tar::Archive::new(flate2::read::GzDecoder::new(reader));
    tarball.set_preserve_permissions(true);
    tarball.set_preserve_ownerships(false);
    tarball.set_unpack_xattrs(false);
    tarball
        .unpack(dest)
        .context("Failed to extract gzip archive")
}
/// Unpack an xz-compressed tar stream from `reader` into `dest`.
///
/// Permissions are preserved; ownership and extended attributes are not.
fn extract_xz<R: std::io::Read>(reader: R, dest: &Path) -> Result<()> {
    let mut tarball = tar::Archive::new(xz2::read::XzDecoder::new(reader));
    tarball.set_preserve_permissions(true);
    tarball.set_preserve_ownerships(false);
    tarball.set_unpack_xattrs(false);
    tarball.unpack(dest).context("Failed to extract xz archive")
}
/// Unpack a zstd-compressed tar stream from `reader` into `dest`.
///
/// Permissions are preserved; ownership and extended attributes are not.
///
/// # Errors
/// Fails when the zstd decoder cannot be created or the archive cannot be
/// unpacked.
fn extract_zst<R: std::io::Read>(reader: R, dest: &Path) -> Result<()> {
    let mut tarball = tar::Archive::new(zstd::Decoder::new(reader)?);
    tarball.set_preserve_permissions(true);
    tarball.set_preserve_ownerships(false);
    tarball.set_unpack_xattrs(false);
    tarball
        .unpack(dest)
        .context("Failed to extract zstd archive")
}
/// Unpack an uncompressed tar stream from `reader` into `dest`.
///
/// Permissions are preserved; ownership and extended attributes are not.
fn extract_tar<R: std::io::Read>(reader: R, dest: &Path) -> Result<()> {
    let mut tarball = tar::Archive::new(reader);
    tarball.set_preserve_permissions(true);
    tarball.set_preserve_ownerships(false);
    tarball.set_unpack_xattrs(false);
    tarball
        .unpack(dest)
        .context("Failed to extract tar archive")
}
+300 -2
View File
@@ -1,3 +1,301 @@
fn main() {
println!("Hello, world!");
mod chroot;
mod cli;
mod config;
mod distro;
mod download;
mod extract;
mod mount;
mod namespace;
mod qemu;
mod verbose;
/// Print to stderr only when --verbose / -v is active.
///
/// Takes the same arguments as `eprintln!` and forwards them unchanged when
/// `crate::verbose::is_verbose()` returns true; otherwise nothing is printed.
/// Exported with `#[macro_export]`, so other modules import it as
/// `crate::veprintln`.
#[macro_export]
macro_rules! veprintln {
// Accept any token sequence so every `eprintln!` form is supported.
($($arg:tt)*) => {
// `$crate` keeps the path to `verbose` valid wherever the macro expands.
if $crate::verbose::is_verbose() {
eprintln!($($arg)*);
}
};
}
use anyhow::Result;
use clap::Parser;
use cli::Args;
use config::Config;
use distro::{
map_arch, parse_image_ref, resolve_distro_url, resolve_distro_version, Distro, ImageSource,
};
use download::{digest_sidecar, download_image, fetch_oci_digest};
use extract::extract_tarball;
/// Program entry point.
///
/// Steps, in order:
/// 1. Parse the command line, set verbosity and reject contradictory flags.
/// 2. Load the config file and resolve the image reference and architecture.
/// 3. Make sure the image is in the cache, downloading it when it is absent,
///    when `--no-cache` is given, or when a `latest` OCI tag has a new digest.
/// 4. Check QEMU binfmt (foreign architectures) and user-namespace support.
/// 5. Extract the image into a temporary rootfs and run the command (or a
///    shell) inside the namespaces via `chroot::run_chroot`.
///
/// # Errors
/// Any failure is returned to the Rust runtime, which reports it on stderr and
/// exits with a non-zero status.
fn main() -> Result<()> {
    let args = Args::parse();
    // Initialise verbosity before anything else so all downstream code can use veprintln!.
    verbose::set(args.verbose);

    // --no-bind means "skip mounting any directory". Combining it with an
    // explicit --bind-rw is contradictory; error rather than silently ignoring
    // the flag the user asked for. Checked up front so the user does not wait
    // for a download before learning about the bad flag combination.
    if args.no_bind && !args.bind_rw.is_empty() {
        return Err(anyhow::anyhow!(
            "--no-bind and --bind-rw cannot be used together: \
             --no-bind skips all mounts, including read-write ones"
        ));
    }

    // Load config file
    let config = Config::load()?;

    // Get architecture
    let host_arch = get_host_arch();
    let arch = args.arch.clone().unwrap_or_else(|| host_arch.clone());

    // Parse image reference
    let image_source = parse_image_ref(&args.distro, &arch)?;

    // For DirectTarball, resolve floating aliases ("latest", "lts") to a concrete
    // version string *before* computing the cache key. This ensures we cache as
    // e.g. "ubuntu-noble-amd64" rather than "ubuntu-latest-amd64", so a future
    // release automatically gets its own cache entry.
    let image_source = match image_source {
        ImageSource::DirectTarball { distro, version } => {
            let resolved = resolve_distro_version(&distro, version.as_deref(), &arch)?;
            ImageSource::DirectTarball {
                distro,
                version: Some(resolved),
            }
        }
        other => other,
    };

    // Determine cache directory and filename
    let cache_dir = dirs::cache_dir()
        .ok_or_else(|| anyhow::anyhow!("Could not determine cache directory"))?
        .join("ecr");
    let cache_filename = generate_cache_filename(&image_source, &arch);
    let cache_path = cache_dir.join(&cache_filename);

    // OCI images with a floating tag (":latest") need a freshness check:
    // fetch the current manifest digest from the registry and compare it
    // against the digest stored from the last download. Only re-pull when
    // the digest has actually changed. On a network error we fall back to
    // the cached image with a warning rather than hard-failing.
    let oci_digest_changed = if cache_path.exists() {
        if let ImageSource::OciImage {
            registry,
            repository,
            tag,
            ..
        } = &image_source
        {
            if tag == "latest" {
                match fetch_oci_digest(registry, repository, tag) {
                    Ok(current) => {
                        let stored = std::fs::read_to_string(digest_sidecar(&cache_path)).ok();
                        stored.as_deref() != Some(current.trim())
                    }
                    Err(e) => {
                        eprintln!(
                            "Warning: could not check image freshness ({}); using cache",
                            e
                        );
                        false
                    }
                }
            } else {
                false // pinned tags are assumed immutable
            }
        } else {
            false
        }
    } else {
        false // cache absent — download triggered by !cache_path.exists() below
    };

    // Download if not cached, --no-cache, or the remote digest has moved
    if args.no_cache || !cache_path.exists() || oci_digest_changed {
        std::fs::create_dir_all(&cache_dir)?;
        download_image(&image_source, &cache_path, &arch)?;
    } else {
        veprintln!("Using cached tarball: {}", cache_path.display());
    }

    // Check QEMU if foreign architecture
    if arch != host_arch {
        qemu::check_binfmt(&arch)?;
    }

    // Check user namespace availability
    namespace::check_user_namespace()?;

    // Process bind paths - use current directory if none specified. The
    // working directory is only looked up when it is actually needed, and a
    // failure is reported as an error instead of a panic.
    let bind_paths: Vec<std::path::PathBuf> = if args.bind.is_empty() && !args.no_bind {
        let cwd = std::env::current_dir()
            .map_err(|e| anyhow::anyhow!("Could not get current directory: {}", e))?;
        vec![cwd]
    } else {
        args.bind.clone()
    };
    let bind_rw_paths: Vec<std::path::PathBuf> = args.bind_rw.clone();

    // Create temp directory for extraction
    let temp_dir = tempfile::tempdir()?;
    let rootfs = temp_dir.path().to_path_buf();
    veprintln!("Extracting to: {}", rootfs.display());
    extract_tarball(&cache_path, &rootfs)?;

    // Prepare data for the closure
    let bind_paths_clone = bind_paths.clone();
    let bind_rw_paths_clone = bind_rw_paths.clone();
    let args_clone = args.clone();
    let rootfs_clone = rootfs.clone();
    let dns_clone = config.dns.clone();

    // Run in namespace
    let result = namespace::setup_namespaces(move || -> Result<()> {
        // Setup mounts - overlay_temps must be kept alive for overlay to work
        let overlay_temps = mount::setup_mounts(
            &rootfs_clone,
            &bind_paths_clone,
            &bind_rw_paths_clone,
            &args_clone,
        )?;

        // Write resolv.conf with DNS from config
        write_resolv_conf(&rootfs_clone, &dns_clone)?;

        // Run chroot
        let command = if args_clone.command.is_empty() {
            None
        } else {
            Some(args_clone.command.clone())
        };
        let result = chroot::run_chroot(&rootfs_clone, command, &bind_rw_paths_clone);

        // Keep overlay_temps alive until chroot exits
        drop(overlay_temps);
        result
    });

    // Cleanup happens automatically via tempfile.
    // The error is deliberately not printed here: returning `Err` from `main`
    // already makes the runtime report it on stderr, so an extra eprintln!
    // would show the same failure twice.
    if result.is_ok() {
        veprintln!("Cleanup complete.");
    }
    result
}
/// Build the cache file name for `source` on architecture `arch`.
///
/// * Direct tarballs become `<distro>-<version>-<distro arch>.<ext>`, where a
///   missing version is written as `latest` and `<ext>` is taken from the
///   resolved download URL (an unresolvable URL yields the default extension).
/// * OCI images become `oci-<registry>-<repository>-<tag>-<arch>.tar.gz`, with
///   `.`/`:` in the registry and `/`/`:` in the repository replaced by `_`.
fn generate_cache_filename(source: &ImageSource, arch: &str) -> String {
    match source {
        ImageSource::DirectTarball { distro, version } => {
            let version_ref = version.as_deref();
            let name = match distro {
                Distro::Ubuntu => "ubuntu",
                Distro::Alpine => "alpine",
            };
            let arch_label = map_arch(*distro, arch);
            // The extension comes from the download URL's suffix.
            let download_url = resolve_distro_url(distro, version_ref, arch).unwrap_or_default();
            let extension = get_tarball_extension(&download_url);
            format!(
                "{}-{}-{}.{}",
                name,
                version_ref.unwrap_or("latest"),
                arch_label,
                extension
            )
        }
        ImageSource::OciImage {
            registry,
            repository,
            tag,
            architecture,
        } => {
            // Replace characters that are awkward in file names.
            let registry_part = registry.replace(|c: char| matches!(c, '.' | ':'), "_");
            let repository_part = repository.replace(|c: char| matches!(c, '/' | ':'), "_");
            format!(
                "oci-{}-{}-{}-{}.tar.gz",
                registry_part, repository_part, tag, architecture
            )
        }
    }
}
/// Return the tarball extension (without the leading dot) that `url` ends
/// with: one of `tar.zst`, `tar.xz`, `tar.gz` or `tar.bz2`.
///
/// Any other ending, including an empty string, yields the default `tar.gz`.
fn get_tarball_extension(url: &str) -> &str {
    const KNOWN_SUFFIXES: [&str; 4] = [".tar.zst", ".tar.xz", ".tar.gz", ".tar.bz2"];
    KNOWN_SUFFIXES
        .iter()
        .copied()
        .find(|suffix| url.ends_with(*suffix))
        // Drop the leading '.' of the matched suffix.
        .map_or("tar.gz", |suffix| &suffix[1..])
}
/// Report the architecture of the running machine using distro-style names.
///
/// The machine string comes from the `uname(2)` syscall at run time rather
/// than from a compile-time constant, and without spawning a subprocess.
/// Known machine names are translated (`x86_64` → `amd64`, `aarch64` →
/// `arm64`, `armv7l`/`armv7` → `armhf`, `ppc64le` → `ppc64el`); every other
/// value, including `riscv64` and `s390x`, is returned unchanged.
///
/// # Panics
/// Panics when the `uname(2)` call fails.
fn get_host_arch() -> String {
    let uts = nix::sys::utsname::uname()
        .expect("uname(2) syscall failed — cannot determine host architecture");
    let machine = uts.machine().to_string_lossy();
    let normalized: &str = match &*machine {
        "x86_64" => "amd64",
        "aarch64" => "arm64",
        "armv7l" | "armv7" => "armhf",
        "ppc64le" => "ppc64el",
        // Names that already match (riscv64, s390x) and unknown machines.
        other => other,
    };
    normalized.to_string()
}
/// Write `<rootfs>/etc/resolv.conf` as a plain file.
///
/// * `rootfs` – root of the extracted image (host-side path).
/// * `dns`    – nameserver addresses; when empty, the host's
///   `/etc/resolv.conf` is copied, falling back to 1.1.1.1 / 8.8.8.8 if the
///   host file cannot be read.
///
/// # Errors
/// Fails when `<rootfs>/etc` cannot be created, when an existing
/// `resolv.conf` cannot be removed, or when the new file cannot be created
/// and written.
fn write_resolv_conf(rootfs: &std::path::Path, dns: &[String]) -> Result<()> {
    use std::io::Write;

    let resolv_conf = rootfs.join("etc/resolv.conf");
    // Create /etc if it doesn't exist
    if let Some(parent) = resolv_conf.parent() {
        std::fs::create_dir_all(parent)?;
    }

    // Copy host's resolv.conf if dns is empty, otherwise use provided DNS
    let content = if dns.is_empty() {
        std::fs::read_to_string("/etc/resolv.conf")
            .unwrap_or_else(|_| "nameserver 1.1.1.1\nnameserver 8.8.8.8\n".to_string())
    } else {
        dns.iter()
            .map(|server| format!("nameserver {}\n", server))
            .collect::<String>()
    };

    // Remove any existing file or symlink first. Images often ship an absolute
    // symlink such as /etc/resolv.conf -> /run/systemd/resolve/stub-resolv.conf;
    // writing through it would resolve against the *host* root (chroot() has
    // not been called yet). Only "not found" is acceptable here: if the old
    // entry cannot be removed we must not go on and write through it.
    match std::fs::remove_file(&resolv_conf) {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => return Err(e.into()),
    }

    // create_new refuses to open a path that already exists (symlinks
    // included), so the data can only ever land in a freshly created file.
    let mut file = std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&resolv_conf)?;
    file.write_all(content.as_bytes())?;
    Ok(())
}
+278
View File
@@ -0,0 +1,278 @@
use anyhow::{anyhow, Context, Result};
use nix::mount::{mount, MsFlags};
use std::path::Path;
use tempfile::TempDir;
/// Escape `path` so it can be embedded as a value in an overlayfs option
/// string.
///
/// The options are passed as one comma-separated string, so a literal `,` in
/// a path would be read as an option boundary. Each `,` and each `\` (the
/// escape character itself) is therefore prefixed with a backslash.
///
/// # Errors
/// Fails when the path is not valid UTF-8.
fn escape_overlay_path(path: &Path) -> Result<String> {
    let raw = match path.to_str() {
        Some(text) => text,
        None => {
            return Err(anyhow!(
                "Overlay path '{}' contains non-UTF-8 characters",
                path.display()
            ))
        }
    };

    // Single pass: each character is emitted once, preceded by a backslash
    // when it is special, so no escape is ever escaped a second time.
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if matches!(ch, '\\' | ',') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    Ok(escaped)
}
use crate::cli::Args;
/// Prepare every mount the chroot needs underneath `rootfs`.
///
/// Order of operations: mark all mounts private (failure is only a warning),
/// mount `/proc`, `/dev`, `/dev/pts`, try `/sys` (failure is only a warning),
/// add an overlay for each entry of `bind_paths` unless `args.no_bind` is set,
/// and finally add the read-write bind mounts from `bind_rw_paths`. A path
/// present in both lists is mounted read-write only.
///
/// Returns the temporary directories backing the overlays; they must be kept
/// alive for as long as the overlays are in use.
///
/// # Errors
/// Fails when `/proc`, `/dev`, `/dev/pts`, an overlay or a read-write bind
/// mount cannot be set up.
pub fn setup_mounts(
    rootfs: &Path,
    bind_paths: &[std::path::PathBuf],
    bind_rw_paths: &[std::path::PathBuf],
    args: &Args,
) -> Result<Vec<TempDir>> {
    // Keep mount changes from propagating to the host.
    let private_flags = MsFlags::MS_PRIVATE | MsFlags::MS_REC;
    if let Err(e) = mount(
        None::<&str>,
        "/",
        None::<&str>,
        private_flags,
        None::<&str>,
    ) {
        eprintln!("Warning: Failed to make mounts private: {}", e);
    }

    // Kernel and device filesystems.
    mount_proc(rootfs)?;
    mount_dev(rootfs)?;
    mount_devpts(rootfs)?;
    // /sys is optional: some environments do not allow mounting it.
    if let Err(e) = mount_sys(rootfs) {
        eprintln!("Warning: Could not mount /sys: {}", e);
    }

    // Overlay-backed (non-persistent) views of the requested bind paths.
    let mut kept_alive: Vec<TempDir> = Vec::new();
    if !args.no_bind {
        for path in bind_paths {
            // A read-write request for the same path takes precedence.
            if bind_rw_paths.contains(path) {
                continue;
            }
            kept_alive.push(setup_overlay(rootfs, path)?);
        }
    }

    // Read-write bind mounts come last.
    for path in bind_rw_paths {
        setup_bind_rw(rootfs, path)?;
    }

    Ok(kept_alive)
}
/// Mount a fresh `proc` filesystem at `<rootfs>/proc` with
/// `nosuid`, `noexec` and `nodev`, creating the directory if needed.
fn mount_proc(rootfs: &Path) -> Result<()> {
    let target = rootfs.join("proc");
    std::fs::create_dir_all(&target)?;
    mount(
        Some("proc"),
        &target,
        Some("proc"),
        MsFlags::MS_NOSUID | MsFlags::MS_NOEXEC | MsFlags::MS_NODEV,
        None::<&str>,
    )
    .with_context(|| format!("Failed to mount proc at {}", target.display()))
}
/// Expose the host's `/sys` read-only at `<rootfs>/sys`.
///
/// Done in two steps: a recursive bind mount of `/sys`, followed by a bind
/// remount that sets `rdonly`, `nosuid`, `nodev` and `noexec`.
///
/// # Errors
/// Fails when the directory cannot be created or either mount call fails.
fn mount_sys(rootfs: &Path) -> Result<()> {
    let target = rootfs.join("sys");
    std::fs::create_dir_all(&target)?;

    // Step 1: recursive bind mount of the host's /sys.
    let bind_flags = MsFlags::MS_BIND | MsFlags::MS_REC;
    mount(
        Some("/sys"),
        &target,
        None::<&str>,
        bind_flags,
        None::<&str>,
    )
    .with_context(|| format!("Failed to bind mount sys at {}", target.display()))?;

    // Step 2: remount read-only. Every wanted restriction is listed
    // explicitly on the remount instead of relying on the first mount's flags.
    let readonly_flags = MsFlags::MS_BIND
        | MsFlags::MS_REMOUNT
        | MsFlags::MS_RDONLY
        | MsFlags::MS_NOSUID
        | MsFlags::MS_NODEV
        | MsFlags::MS_NOEXEC;
    mount(
        Some(&target),
        &target,
        None::<&str>,
        readonly_flags,
        None::<&str>,
    )
    .with_context(|| {
        format!(
            "Failed to remount sys as read-only at {}",
            target.display()
        )
    })
}
/// Recursively bind mount the host's `/dev` onto `<rootfs>/dev`.
///
/// The target directory is created first if it does not exist.
///
/// # Errors
/// Fails if the directory cannot be created or the bind mount is rejected.
fn mount_dev(rootfs: &Path) -> Result<()> {
    let target = rootfs.join("dev");
    std::fs::create_dir_all(&target)?;

    let bind_flags = MsFlags::MS_REC | MsFlags::MS_BIND;
    mount(Some("/dev"), &target, None::<&str>, bind_flags, None::<&str>)
        .with_context(|| format!("Failed to bind mount dev at {}", target.display()))
}
/// Mount a devpts filesystem at `<rootfs>/dev/pts` (nosuid, noexec).
///
/// The target directory is created first if it does not exist.
///
/// # Errors
/// Fails if the directory cannot be created or the mount call is rejected.
fn mount_devpts(rootfs: &Path) -> Result<()> {
    let target = rootfs.join("dev/pts");
    std::fs::create_dir_all(&target)?;

    mount(
        Some("devpts"),
        &target,
        Some("devpts"),
        MsFlags::MS_NOEXEC | MsFlags::MS_NOSUID,
        None::<&str>,
    )
    .with_context(|| format!("Failed to mount devpts at {}", target.display()))
}
/// Mount an overlay of `source` at `<rootfs>/root/<basename of source>`.
///
/// `source` is used as the overlay's lower directory. The upper and work
/// directories live inside a freshly created temporary directory.
///
/// Returns that temporary directory; the caller must keep it alive for as
/// long as the overlay is in use.
///
/// # Errors
/// Fails if `source` has no final path component, if any directory cannot be
/// created or canonicalized, if a path cannot be escaped for the option
/// string, or if the overlay mount itself is rejected.
fn setup_overlay(rootfs: &Path, source: &Path) -> Result<TempDir> {
    let name = source
        .file_name()
        .ok_or_else(|| anyhow!("Invalid bind path"))?
        .to_string_lossy();
    let target = rootfs.join("root").join(&*name);
    std::fs::create_dir_all(&target)?;

    // Scratch space backing the writable layer of the overlay.
    let scratch = tempfile::tempdir()?;
    let upper = scratch.path().join("upper");
    let work = scratch.path().join("work");
    std::fs::create_dir_all(&upper)?;
    std::fs::create_dir_all(&work)?;

    // Resolve every layer to an absolute path, then escape it for use inside
    // the comma-separated overlay option string.
    let lower = source.canonicalize()?;
    let upper = upper.canonicalize()?;
    let work = work.canonicalize()?;
    let lower_opt = escape_overlay_path(&lower)?;
    let upper_opt = escape_overlay_path(&upper)?;
    let work_opt = escape_overlay_path(&work)?;
    let options = format!(
        "lowerdir={},upperdir={},workdir={}",
        lower_opt, upper_opt, work_opt,
    );

    mount(
        Some("overlay"),
        &target,
        Some("overlay"),
        MsFlags::empty(),
        Some(options.as_str()),
    )
    .with_context(|| format!("Failed to mount overlay at {}", target.display()))?;

    // Ownership of the scratch directory moves to the caller.
    Ok(scratch)
}
/// Recursively bind mount `source` read-write at `<rootfs>/mnt/<basename of source>`.
///
/// The mount point is created if missing, and `source` is canonicalized
/// before being handed to the mount call.
///
/// # Errors
/// Fails if `source` has no final path component, if the mount point cannot
/// be created, if `source` cannot be canonicalized, or if the bind mount is
/// rejected.
fn setup_bind_rw(rootfs: &Path, source: &Path) -> Result<()> {
    let name = source
        .file_name()
        .ok_or_else(|| anyhow!("Invalid bind-rw path"))?
        .to_string_lossy();
    let target = rootfs.join("mnt").join(&*name);
    std::fs::create_dir_all(&target)?;

    let resolved = source.canonicalize()?;
    let bind_flags = MsFlags::MS_REC | MsFlags::MS_BIND;
    mount(
        Some(&resolved),
        &target,
        None::<&str>,
        bind_flags,
        None::<&str>,
    )
    .with_context(|| {
        format!(
            "Failed to bind mount {} at {}",
            resolved.display(),
            target.display()
        )
    })
}
#[cfg(test)]
mod tests {
    use super::escape_overlay_path;
    use std::path::Path;

    /// Escape `input` and assert that the result equals `expected`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test
    /// rather than at this helper.
    #[track_caller]
    fn assert_escapes_to(input: &str, expected: &str) {
        assert_eq!(escape_overlay_path(Path::new(input)).unwrap(), expected);
    }

    #[test]
    fn plain_path_unchanged() {
        assert_escapes_to("/home/user/project", "/home/user/project");
    }

    #[test]
    fn comma_in_path_escaped() {
        assert_escapes_to("/home/user/my,project", "/home/user/my\\,project");
    }

    #[test]
    fn backslash_escaped_before_comma() {
        // The backslash has to be doubled before the comma is escaped:
        // `a\,b` must come out as `a\\\,b`. Left as `a\,b` it would read as
        // an already-escaped comma.
        assert_escapes_to("/a\\,b", "/a\\\\\\,b");
    }

    #[test]
    fn multiple_commas_all_escaped() {
        assert_escapes_to("/a,b,c", "/a\\,b\\,c");
    }
}
+413
View File
@@ -0,0 +1,413 @@
use anyhow::{anyhow, Context, Result};
use nix::sched::{clone, CloneFlags};
use nix::sys::signal::Signal;
use nix::unistd::{getgid, getuid, Pid};
/// Owns a raw file descriptor and closes it when the value is dropped.
///
/// `setup_namespaces` wraps every pipe end in this type so that each early
/// return closes the descriptors without explicit cleanup code.
struct AutoCloseFd(i32);

impl AutoCloseFd {
    /// Return the wrapped descriptor number; ownership stays with `self`.
    fn raw(&self) -> i32 {
        let Self(fd) = self;
        *fd
    }
}

impl Drop for AutoCloseFd {
    fn drop(&mut self) {
        // SAFETY: close(2) takes a plain integer and touches no Rust-managed
        // memory. The return value is deliberately ignored: nothing useful
        // can be done about a failed close during drop.
        unsafe {
            libc::close(self.0);
        }
    }
}
/// Namespaces requested when cloning the child in `setup_namespaces`:
/// a new UTS, mount, PID and user namespace.
const CLONE_FLAGS: CloneFlags = CloneFlags::CLONE_NEWUTS
    .union(CloneFlags::CLONE_NEWNS)
    .union(CloneFlags::CLONE_NEWPID)
    .union(CloneFlags::CLONE_NEWUSER);
/// Verify that the kernel settings allow creating user namespaces.
///
/// Two sysctl files are checked. One that is missing or unreadable is treated
/// as "no restriction", and so is a `max_user_namespaces` value that does not
/// parse as a number.
///
/// # Errors
/// Returns an error with remediation hints when
/// `kernel.unprivileged_userns_clone` is `0` or `user.max_user_namespaces`
/// is `0`.
pub fn check_user_namespace() -> Result<()> {
    // Only some systems have this knob; absence is not an error.
    let unprivileged_clone_disabled =
        std::fs::read_to_string("/proc/sys/kernel/unprivileged_userns_clone")
            .map_or(false, |value| value.trim() == "0");
    if unprivileged_clone_disabled {
        return Err(anyhow!(
            "User namespaces not available\n\n\
             Enable with:\n\
             sysctl -w kernel.unprivileged_userns_clone=1\n\n\
             Or check AppArmor profile restrictions."
        ));
    }

    // A limit of zero means no user namespace can be created at all.
    let namespace_limit = std::fs::read_to_string("/proc/sys/user/max_user_namespaces")
        .ok()
        .and_then(|value| value.trim().parse::<u32>().ok());
    if namespace_limit == Some(0) {
        return Err(anyhow!(
            "User namespaces not available\n\n\
             Enable with:\n\
             sysctl -w user.max_user_namespaces=10000"
        ));
    }

    Ok(())
}
/// Setup namespaces and run the provided function inside them
///
/// A child is cloned into the namespaces named by `CLONE_FLAGS`. The parent
/// then writes the child's UID/GID mappings (`setup_user_namespace`) and only
/// afterwards releases the child to run `f`. The parent blocks until the
/// child has exited.
///
/// If the mappings cannot be set up, the child is never released: it exits
/// without running `f`, the parent reaps it, and the mapping error is
/// returned.
///
/// # Errors
/// - a pipe could not be created or `clone` failed;
/// - `setup_user_namespace` failed (`f` is not run in that case);
/// - `f` returned an error in the child (its `{:#}` error chain is forwarded
///   through a pipe and returned as the error message);
/// - the child exited with a non-zero status or was killed by a signal.
pub fn setup_namespaces<F>(f: F) -> Result<()>
where
    F: FnOnce() -> Result<()> + Send + 'static,
{
    // sync pipe:  parent signals child to proceed after UID/GID mapping
    // done pipe:  child signals parent it has finished (or exec'd the shell)
    // error pipe: child writes the anyhow error chain to the parent on failure.
    //   The write end is close-on-exec, so it is closed automatically when
    //   the child successfully execs; the parent then reads EOF and knows
    //   there was no error.
    //
    // All six fds are wrapped in AutoCloseFd so they are closed on every return
    // path, including clone() and setup_user_namespace() failures.
    let (parent_read, parent_write, child_read, child_write, error_read, error_write);
    // SAFETY: pipe(2) is given a valid two-element buffer and fcntl(2) is only
    // called on descriptors that pipe(2) just returned.
    unsafe {
        let mut fds: [i32; 2] = [-1, -1];
        if libc::pipe(fds.as_mut_ptr()) != 0 {
            return Err(anyhow!("Failed to create sync pipe"));
        }
        parent_read = AutoCloseFd(fds[0]);
        parent_write = AutoCloseFd(fds[1]);
        // parent_read/write are auto-closed if a subsequent pipe() fails.
        if libc::pipe(fds.as_mut_ptr()) != 0 {
            return Err(anyhow!("Failed to create done pipe"));
        }
        child_read = AutoCloseFd(fds[0]);
        child_write = AutoCloseFd(fds[1]);
        // Close-on-exec on the write end: if exec succeeds the kernel closes
        // it, the parent's read(child_read) gets EOF immediately, and waitpid
        // becomes the real wait. The done pipe is then only used on the
        // error path (f() returned Err before exec was reached).
        libc::fcntl(child_write.raw(), libc::F_SETFD, libc::FD_CLOEXEC);
        if libc::pipe(fds.as_mut_ptr()) != 0 {
            return Err(anyhow!("Failed to create error pipe"));
        }
        error_read = AutoCloseFd(fds[0]);
        error_write = AutoCloseFd(fds[1]);
        // Same treatment for error_write: auto-closed on exec (no error),
        // written explicitly on the error path before the child exits.
        libc::fcntl(error_write.raw(), libc::F_SETFD, libc::FD_CLOEXEC);
    }

    // Stack for the child process
    let stack_size = 1024 * 1024;
    let mut stack = vec![0u8; stack_size];

    // Wrap f in Option to allow taking it once inside the child closure
    let mut f = Some(f);

    // Extract raw fds for the child closure. The child is a clone of the
    // parent process and gets its own copies of all open fds; the parent's
    // AutoCloseFd wrappers independently manage the parent's copies.
    let pr = parent_read.raw();
    let pw = parent_write.raw();
    let cr = child_read.raw();
    let cw = child_write.raw();
    let er = error_read.raw();
    let ew = error_write.raw();

    // Clone with new namespaces
    // SAFETY: the closure only uses descriptors created above and data moved
    // into it; `stack` outlives the clone() call.
    let pid = unsafe {
        clone(
            Box::new(move || {
                // Close unused pipe ends in the child
                libc::close(pw);
                libc::close(cr);
                libc::close(er);

                // Wait for the parent to set up UID/GID mappings. Only a real
                // byte counts as the go-ahead: EOF means the parent gave up
                // (its mapping step failed) and closed the pipe, in which
                // case `f` must not run in the unmapped namespace.
                let mut buf = [0u8; 1];
                let released = loop {
                    let n = libc::read(pr, buf.as_mut_ptr() as *mut libc::c_void, 1);
                    if n < 0
                        && std::io::Error::last_os_error().kind()
                            == std::io::ErrorKind::Interrupted
                    {
                        continue; // interrupted by a signal; retry
                    }
                    break n == 1;
                };
                libc::close(pr);
                if !released {
                    libc::close(ew);
                    libc::close(cw);
                    return 1;
                }

                // Run the function
                let result = if let Some(func) = f.take() {
                    func()
                } else {
                    Err(anyhow!("Function already called"))
                };

                // On failure, write the full error chain to the error pipe
                // before signalling done, so the parent can reconstruct it.
                if let Err(ref e) = result {
                    let msg = format!("{:#}", e);
                    let bytes = msg.as_bytes();
                    libc::write(ew, bytes.as_ptr() as *const libc::c_void, bytes.len());
                }
                libc::close(ew);

                // Signal completion
                libc::write(cw, c"done".as_ptr() as *const libc::c_void, 4);
                libc::close(cw);

                if result.is_ok() {
                    0
                } else {
                    1
                }
            }),
            &mut stack,
            CLONE_FLAGS,
            Some(Signal::SIGCHLD as i32),
        )
    }
    .context("Failed to clone with new namespaces")?;
    // clone() failure: all six AutoCloseFds drop here, closing every fd.

    // Parent: drop the child-side ends now that clone has succeeded.
    // The child process has its own copies; dropping here closes the parent's.
    drop(parent_read);
    drop(child_write);
    drop(error_write);

    // Set up UID/GID mappings for the child. On failure, close the sync pipe
    // without writing to it: the child sees EOF, exits without running `f`,
    // and is reaped here so it does not linger as a zombie. The remaining
    // parent-side fds are closed by AutoCloseFd on return.
    if let Err(e) = setup_user_namespace(pid) {
        drop(parent_write);
        let _ = nix::sys::wait::waitpid(pid, None);
        return Err(e);
    }

    // Signal child to proceed
    // SAFETY: writes two bytes from a static C string to an fd we own.
    unsafe {
        libc::write(parent_write.raw(), c"go".as_ptr() as *const libc::c_void, 2);
    }
    drop(parent_write);

    // Wait for child to complete (or the exec'd shell to exit)
    let mut buf = [0u8; 4];
    // SAFETY: `buf` is a valid 4-byte buffer and the fd is owned by us.
    unsafe {
        libc::read(child_read.raw(), buf.as_mut_ptr() as *mut libc::c_void, 4);
    }
    drop(child_read);

    // Read the error message written by the child, if any.
    // error_write was either closed explicitly (on error) or auto-closed via
    // close-on-exec (on successful exec), so this read always terminates.
    // SAFETY: `tmp` is a valid buffer of `tmp.len()` bytes; `n` is checked to
    // be positive before it is used as a length.
    let child_error: Option<String> = unsafe {
        let mut error_bytes = Vec::new();
        let mut tmp = [0u8; 4096];
        loop {
            let n = libc::read(
                error_read.raw(),
                tmp.as_mut_ptr() as *mut libc::c_void,
                tmp.len(),
            );
            if n <= 0 {
                break;
            }
            error_bytes.extend_from_slice(&tmp[..n as usize]);
        }
        if error_bytes.is_empty() {
            None
        } else {
            Some(String::from_utf8_lossy(&error_bytes).into_owned())
        }
    };
    drop(error_read);

    // Wait for child process
    let status = nix::sys::wait::waitpid(pid, None)?;
    match status {
        nix::sys::wait::WaitStatus::Exited(_, 0) => Ok(()),
        nix::sys::wait::WaitStatus::Exited(_, code) => {
            if let Some(msg) = child_error {
                Err(anyhow!("{}", msg))
            } else {
                Err(anyhow!("Child process exited with code {}", code))
            }
        }
        nix::sys::wait::WaitStatus::Signaled(_, sig, _) => {
            Err(anyhow!("Child process killed by signal {:?}", sig))
        }
        _ => Ok(()),
    }
}
/// Write the UID and GID mappings for the process `pid`.
///
/// Inside the namespace, ID 0 is mapped to the calling user's real UID/GID,
/// and IDs from 1 upwards are mapped to the subordinate range found in
/// `/etc/subuid` / `/etc/subgid` (all but one ID of that range). The mappings
/// are applied by running the external `newuidmap` and `newgidmap` tools.
///
/// # Errors
/// Fails if no subordinate range is found, if a range holds fewer than two
/// IDs, if `/proc/<pid>/setgroups` cannot be written, or if either mapping
/// tool cannot be run or exits unsuccessfully.
fn setup_user_namespace(pid: Pid) -> Result<()> {
    let uid = getuid();
    let gid = getgid();

    // Subordinate ranges come from /etc/subuid and /etc/subgid.
    let (sub_uid_start, sub_uid_count) = get_subuid_range(uid)?;
    let (sub_gid_start, sub_gid_count) = get_subgid_range(gid)?;

    // One mapping line is spent on "0 -> host user"; the second covers
    // `count - 1` subordinate IDs. A count below 2 would leave that second
    // line empty, which points at a malformed configuration entry.
    if sub_uid_count < 2 {
        return Err(anyhow!(
            "subuid count {} for uid {} is too small (need at least 2); \
             check /etc/subuid",
            sub_uid_count,
            uid
        ));
    }
    if sub_gid_count < 2 {
        return Err(anyhow!(
            "subgid count {} for gid {} is too small (need at least 2); \
             check /etc/subgid",
            sub_gid_count,
            gid
        ));
    }

    // Allow setgroups so apt and other tools can drop privileges.
    let setgroups_path = format!("/proc/{}/setgroups", pid);
    std::fs::write(&setgroups_path, "allow\n")
        .with_context(|| format!("Failed to write {}", setgroups_path))?;

    let target = pid.to_string();

    // Argument layout: <pid> then (ns_start, host_start, count) triples.
    // Triple 1: namespace root (0) -> current user, one ID.
    // Triple 2: namespace IDs from 1 -> subordinate range.
    let uid_args = [
        target.clone(),
        "0".to_string(),
        uid.to_string(),
        "1".to_string(),
        "1".to_string(),
        sub_uid_start.to_string(),
        (sub_uid_count - 1).to_string(),
    ];
    let uid_mapped = std::process::Command::new("newuidmap")
        .args(&uid_args)
        .status()
        .context("Failed to execute newuidmap")?
        .success();
    if !uid_mapped {
        return Err(anyhow!(
            "newuidmap failed - ensure subuid entry exists in /etc/subuid"
        ));
    }

    // Same layout for groups.
    let gid_args = [
        target,
        "0".to_string(),
        gid.to_string(),
        "1".to_string(),
        "1".to_string(),
        sub_gid_start.to_string(),
        (sub_gid_count - 1).to_string(),
    ];
    let gid_mapped = std::process::Command::new("newgidmap")
        .args(&gid_args)
        .status()
        .context("Failed to execute newgidmap")?
        .success();
    if !gid_mapped {
        return Err(anyhow!(
            "newgidmap failed - ensure subgid entry exists in /etc/subgid"
        ));
    }

    Ok(())
}
/// Get subordinate UID range for a user from /etc/subuid
///
/// Scans the file line by line for the first `owner:start:count` entry whose
/// owner is either the user's login name or the numeric UID, and returns
/// `(start, count)` from it. Lines with fewer than three `:`-separated
/// fields are skipped.
///
/// # Errors
/// Fails if `/etc/subuid` cannot be read, if the matching entry has a
/// non-numeric start or count, or if no entry matches.
fn get_subuid_range(uid: nix::unistd::Uid) -> Result<(u32, u32)> {
    let content = std::fs::read_to_string("/etc/subuid").context("Failed to read /etc/subuid")?;
    let username =
        users::get_user_by_uid(uid.as_raw()).map(|u| u.name().to_string_lossy().to_string());

    for line in content.lines() {
        let mut fields = line.split(':');
        let (Some(owner), Some(start), Some(count)) =
            (fields.next(), fields.next(), fields.next())
        else {
            continue;
        };

        // Match by name or by numeric UID. The name is compared as an
        // Option so that a failed name lookup can never match an entry with
        // an empty owner field.
        let matches = username.as_deref() == Some(owner)
            || owner.parse::<u32>().ok() == Some(uid.as_raw());
        if matches {
            let start: u32 = start.parse().context("Invalid subuid start")?;
            let count: u32 = count.parse().context("Invalid subuid count")?;
            return Ok((start, count));
        }
    }

    Err(anyhow!(
        "No subuid entry found for user {} (uid {}). \
         Add one to /etc/subuid, e.g.:\n {}:100000:65536",
        username.as_deref().unwrap_or("<unknown>"),
        uid,
        username.as_deref().unwrap_or(&uid.to_string()),
    ))
}
/// Get subordinate GID range for a group from /etc/subgid
///
/// Scans the file line by line for the first `owner:start:count` entry whose
/// owner is either the group's name or the numeric GID, and returns
/// `(start, count)` from it. Lines with fewer than three `:`-separated
/// fields are skipped.
///
/// # Errors
/// Fails if `/etc/subgid` cannot be read, if the matching entry has a
/// non-numeric start or count, or if no entry matches.
fn get_subgid_range(gid: nix::unistd::Gid) -> Result<(u32, u32)> {
    let content = std::fs::read_to_string("/etc/subgid").context("Failed to read /etc/subgid")?;
    let groupname =
        users::get_group_by_gid(gid.as_raw()).map(|g| g.name().to_string_lossy().to_string());

    for line in content.lines() {
        let mut fields = line.split(':');
        let (Some(owner), Some(start), Some(count)) =
            (fields.next(), fields.next(), fields.next())
        else {
            continue;
        };

        // Match by name or by numeric GID. The name is compared as an
        // Option so that a failed name lookup can never match an entry with
        // an empty owner field.
        let matches = groupname.as_deref() == Some(owner)
            || owner.parse::<u32>().ok() == Some(gid.as_raw());
        if matches {
            let start: u32 = start.parse().context("Invalid subgid start")?;
            let count: u32 = count.parse().context("Invalid subgid count")?;
            return Ok((start, count));
        }
    }

    Err(anyhow!(
        "No subgid entry found for group {} (gid {}). \
         Add one to /etc/subgid, e.g.:\n {}:100000:65536",
        groupname.as_deref().unwrap_or("<unknown>"),
        gid,
        groupname.as_deref().unwrap_or(&gid.to_string()),
    ))
}
/// Set the hostname to `ecr-<distro>-<suffix>`, where `<suffix>` is six
/// pseudo-random characters drawn from `a-z0-9`.
///
/// # Errors
/// Fails if the `sethostname` call is rejected.
pub fn set_hostname(distro: &str) -> Result<()> {
    use nix::unistd::sethostname;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
    const SUFFIX_LEN: usize = 6;

    // Seed from both the wall clock and the PID: the clock alone can repeat
    // when invocations happen in rapid succession.
    let mut seed = DefaultHasher::new();
    std::time::SystemTime::now().hash(&mut seed);
    std::process::id().hash(&mut seed);
    let mut state = seed.finish();

    let mut suffix = String::with_capacity(SUFFIX_LEN);
    for _ in 0..SUFFIX_LEN {
        // One 64-bit linear-congruential step per character; the character
        // is picked from the high bits of the new state.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let index = (state >> 33) as usize % ALPHABET.len();
        suffix.push(ALPHABET[index] as char);
    }

    let hostname = format!("ecr-{}-{}", distro, suffix);
    sethostname(&hostname).with_context(|| format!("Failed to set hostname to {}", hostname))
}
+37
View File
@@ -0,0 +1,37 @@
use crate::veprintln;
use anyhow::{anyhow, Result};
use std::path::Path;
/// Confirm that a `qemu-<arch>` handler is registered with binfmt_misc.
///
/// `arch` is translated to QEMU's naming by `map_arch_to_qemu`, and the
/// matching entry under `/proc/sys/fs/binfmt_misc/` must exist.
///
/// # Errors
/// Returns an error with per-distribution install hints when the entry is
/// missing.
pub fn check_binfmt(arch: &str) -> Result<()> {
    let entry = format!(
        "/proc/sys/fs/binfmt_misc/qemu-{}",
        map_arch_to_qemu(arch)
    );

    if Path::new(&entry).exists() {
        veprintln!("QEMU binfmt_misc registered for {}", arch);
        return Ok(());
    }

    Err(anyhow!(
        "binfmt_misc not registered for {}\n\n\
         Install QEMU user emulation:\n\
         Ubuntu/Debian: sudo apt install qemu-user-static\n\
         Arch: sudo pacman -S qemu-user-static-binfmt\n\
         Alpine: sudo apk add qemu-user-static",
        arch
    ))
}
/// Translate an ecr architecture name into the name QEMU uses.
///
/// Known aliases are rewritten to the QEMU spelling. Any other name,
/// including ones that already are QEMU spellings such as `x86_64` or
/// `s390x`, is returned unchanged.
fn map_arch_to_qemu(arch: &str) -> &str {
    // (alias, QEMU name) pairs. Names identical in both schemes need no
    // entry because unknown names pass through as-is.
    const ALIASES: [(&str, &str); 5] = [
        ("amd64", "x86_64"),
        ("arm64", "aarch64"),
        ("armhf", "arm"),
        ("armv7", "arm"),
        ("ppc64el", "ppc64le"),
    ];

    ALIASES
        .iter()
        .find(|entry| entry.0 == arch)
        .map_or(arch, |entry| entry.1)
}
+11
View File
@@ -0,0 +1,11 @@
use std::sync::atomic::{AtomicBool, Ordering};
/// Memory ordering shared by the store in `set` and the load in `is_verbose`.
const FLAG_ORDERING: Ordering = Ordering::Relaxed;

/// Process-wide verbosity flag; starts out disabled.
static VERBOSE: AtomicBool = AtomicBool::new(false);

/// Turn verbose output on (`true`) or off (`false`).
pub fn set(v: bool) {
    VERBOSE.store(v, FLAG_ORDERING);
}

/// Report whether verbose output is currently enabled.
pub fn is_verbose() -> bool {
    VERBOSE.load(FLAG_ORDERING)
}