From e9a55c9053b0f8694918e7a677732513052b3e3e Mon Sep 17 00:00:00 2001 From: Yuedong Wu Date: Thu, 11 Jun 2026 18:53:32 +0800 Subject: [PATCH 1/2] feat(server): support TLS certificate hot-reload Signed-off-by: Yuedong Wu --- Cargo.lock | 87 ++ crates/openshell-server/Cargo.toml | 2 + crates/openshell-server/src/lib.rs | 19 +- crates/openshell-server/src/tls.rs | 908 +++++++++++++++++++- crates/openshell-server/tests/common/mod.rs | 2 +- 5 files changed, 970 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d0cd77f85..b484fd03f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,15 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "arc-swap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207" +dependencies = [ + "rustversion", +] + [[package]] name = "argon2" version = "0.5.3" @@ -1623,6 +1632,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + [[package]] name = "futures" version = "0.3.32" @@ -2363,6 +2381,26 @@ dependencies = [ "web-time", ] +[[package]] +name = "inotify" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533e68a5842e734946fe159fb03fc9bbbb254f590dd0d8ad321ae5ff7beca2c1" +dependencies = [ + "bitflags", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "inout" version = "0.1.4" @@ -2632,6 +2670,26 @@ version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4933f3f57a8e9d9da04db23fb153356ecaf00cbd14aee46279c33dc80925c37" +[[package]] +name = "kqueue" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "273c0752728918e0ac4976f2b275b6fefb9ecd400585dec929419f3844cd87b5" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07293a4e297ac234359b510362495713f75ea345d5307140414f20c69ffeb087" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "kube" version = "0.90.0" @@ -3150,6 +3208,33 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "notify" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" +dependencies = [ + "bitflags", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio 1.2.0", + "notify-types", + "walkdir", + "windows-sys 0.60.2", +] + +[[package]] +name = "notify-types" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42b8cfee0e339a0337359f3c88165702ac6e600dc01c0cc9579a92d62b08477a" +dependencies = [ + "bitflags", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -3695,6 +3780,7 @@ name = "openshell-server" version = "0.0.0" dependencies = [ "anyhow", + "arc-swap", "async-trait", "axum 0.8.9", "bytes", @@ -3716,6 +3802,7 @@ dependencies = [ "metrics", "metrics-exporter-prometheus", "miette", + "notify", "openshell-bootstrap", "openshell-core", "openshell-driver-docker", diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 0b7e3a97e..61d5ed41b 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -94,6 +94,8 @@ ipnet = "2" tempfile = "3" rustix = { workspace = true } x509-parser = "0.16" +arc-swap = "1" +notify = "8" [features] bundled-z3 = ["openshell-prover/bundled-z3"] diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 676e23071..b7304f899 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -413,20 +413,28 @@ pub async fn run_server( info!("Metrics server disabled"); } + let (shutdown_tx, shutdown_rx) = watch::channel(false); + // Build TLS acceptor when TLS is configured; otherwise serve plaintext. let tls_acceptor = if let Some(tls) = &config.tls { - Some(TlsAcceptor::from_files( + let acceptor = TlsAcceptor::from_files( &tls.cert_path, &tls.key_path, tls.client_ca_path.as_deref(), tls.require_client_auth, - )?) + )?; + + // Spawn file-watcher-based TLS certificate reload worker. + // Watches parent directories of cert/key/CA files and atomically + // reloads when changes are detected. + acceptor.spawn_reload_worker(shutdown_rx.clone()); + + Some(acceptor) } else { info!("TLS disabled — accepting plaintext connections"); None }; - let (shutdown_tx, shutdown_rx) = watch::channel(false); let mut listener_tasks = Vec::with_capacity(gateway_listeners.len()); let enable_loopback_service_http = config.service_routing.enable_loopback_service_http; for (listener, listen_addr) in gateway_listeners { @@ -615,7 +623,10 @@ fn spawn_gateway_connection( warn!(client = %addr, listen = %listen_addr, "Rejected plaintext HTTP on non-loopback gateway listener"); } Ok(ConnectionProtocol::Tls | ConnectionProtocol::Unknown) => { - match acceptor.inner().accept(stream).await { + // acceptor.acceptor() snapshots the current TLS config; + // the returned acceptor owns an Arc that stays alive for + // the full duration of the handshake. + match acceptor.acceptor().accept(stream).await { Ok(tls_stream) => { let peer_identity = multiplex::extract_peer_identity(&tls_stream); if let Err(e) = service diff --git a/crates/openshell-server/src/tls.rs b/crates/openshell-server/src/tls.rs index 1af1ce0cd..8da7f4936 100644 --- a/crates/openshell-server/src/tls.rs +++ b/crates/openshell-server/src/tls.rs @@ -3,19 +3,42 @@ //! TLS support using tokio-rustls. +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use arc_swap::ArcSwap; +use notify::event::EventKind; +use notify::{Event, RecursiveMode, Watcher}; use openshell_core::{Error, Result}; +use openshell_ocsf::{ + ConfigStateChangeBuilder, OCSF_TARGET, SandboxContext, SeverityId, StateId, StatusId, +}; use rustls::ServerConfig; +use rustls::crypto::ring::sign; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::server::WebPkiClientVerifier; -use std::fs::File; -use std::io::BufReader; -use std::path::Path; -use std::sync::Arc; +use tokio::sync::{mpsc, watch}; +use tracing::{debug, info, warn}; /// TLS acceptor for wrapping connections. +/// +/// Uses `ArcSwap` internally so the active `ServerConfig` can be atomically +/// swapped by a background reload worker without blocking TLS handshakes. +/// +/// Stores the cert/key/CA paths from construction so that `reload()` can +/// re-read from the same files without the caller tracking them separately. #[derive(Clone)] pub struct TlsAcceptor { - acceptor: tokio_rustls::TlsAcceptor, + config: Arc>, + cert_path: PathBuf, + key_path: PathBuf, + client_ca_path: Option, + require_client_auth: bool, + reload_spawned: Arc, } impl TlsAcceptor { @@ -42,53 +65,234 @@ impl TlsAcceptor { client_ca_path: Option<&Path>, require_client_auth: bool, ) -> Result { - let certs = load_certs(cert_path)?; - let key = load_key(key_path)?; - - let mut config = if let Some(ca_path) = client_ca_path { - let ca_certs = load_certs(ca_path)?; - let mut root_store = rustls::RootCertStore::empty(); - for cert in ca_certs { - root_store - .add(cert) - .map_err(|e| Error::tls(format!("failed to add CA certificate: {e}")))?; + let config = build_server_config(cert_path, key_path, client_ca_path, require_client_auth)?; + Ok(Self { + config: Arc::new(ArcSwap::from(config)), + cert_path: cert_path.to_path_buf(), + key_path: key_path.to_path_buf(), + client_ca_path: client_ca_path.map(Path::to_path_buf), + require_client_auth, + reload_spawned: Arc::new(AtomicBool::new(false)), + }) + } + + /// Re-read certificates from the same paths used at construction and + /// atomically swap the active config. + /// + /// Returns `Ok(())` when the new config was built and swapped successfully. + /// Returns `Err(...)` if cert/key loading fails — the old config is preserved. + pub fn reload(&self) -> Result<()> { + let new_config = build_server_config( + &self.cert_path, + &self.key_path, + self.client_ca_path.as_deref(), + self.require_client_auth, + )?; + self.config.store(new_config); + + let event = ConfigStateChangeBuilder::new(&tls_ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "reloaded") + .message("TLS certificate config reloaded successfully") + .build(); + info!( + target: OCSF_TARGET, + sandbox_id = "", + message = %event.format_shorthand() + ); + + Ok(()) + } + + /// Return a fresh `tokio_rustls::TlsAcceptor` backed by the current config + /// snapshot. Each call clones the active `Arc` so it remains + /// alive for the duration of the TLS handshake. + #[must_use] + pub fn acceptor(&self) -> tokio_rustls::TlsAcceptor { + tokio_rustls::TlsAcceptor::from(self.config.load_full()) + } + + /// Spawn a background worker that watches the parent directories of the + /// cert, key, and CA files and calls [`reload`](Self::reload) when changes + /// are detected. + /// + /// A 1-second debounce window coalesces rapid filesystem events (such as + /// Kubernetes Secret volume atomic swaps) into a single reload. If reload + /// fails, the old config is preserved and a warning is logged. + /// + /// The worker exits when the `shutdown` watch channel fires, allowing the + /// gateway to perform a graceful shutdown without orphaned reload + /// tasks. + pub fn spawn_reload_worker( + &self, + mut shutdown: watch::Receiver, + ) -> tokio::task::JoinHandle<()> { + if self.reload_spawned.swap(true, Ordering::Relaxed) { + warn!("TLS certificate reload worker already spawned, ignoring duplicate call"); + return tokio::spawn(async {}); + } + + let this = self.clone(); + + // Collect unique parent directories to watch. + let cert_dir = self.cert_path.parent().unwrap_or_else(|| Path::new(".")); + let key_dir = self.key_path.parent().unwrap_or_else(|| Path::new(".")); + let mut dirs = vec![cert_dir.to_path_buf()]; + if key_dir != cert_dir { + dirs.push(key_dir.to_path_buf()); + } + if let Some(ref ca) = self.client_ca_path { + let ca_dir = ca.parent().unwrap_or_else(|| Path::new(".")); + if ca_dir != cert_dir && ca_dir != key_dir { + dirs.push(ca_dir.to_path_buf()); } + } + + let debounce = Duration::from_secs(1); + + tokio::spawn(async move { + let (tx, mut rx) = mpsc::unbounded_channel(); - let verifier_builder = WebPkiClientVerifier::builder(Arc::new(root_store)); - let verifier = if require_client_auth { - verifier_builder - } else { - verifier_builder.allow_unauthenticated() + // recommended_watcher runs its own thread; we bridge events into + // the tokio runtime via the unbounded mpsc channel. + let mut watcher = match notify::recommended_watcher( + move |res: std::result::Result| { + if let Ok(event) = res + && matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) + { + let _ = tx.send(()); + } + }, + ) { + Ok(w) => w, + Err(e) => { + warn!(error = %e, "Failed to start TLS cert file watcher, hot-reload disabled"); + return; + } + }; + + for dir in &dirs { + if let Err(e) = watcher.watch(dir, RecursiveMode::NonRecursive) { + warn!(error = %e, dir = %dir.display(), "Failed to watch TLS cert directory, hot-reload disabled"); + return; + } } - .build() - .map_err(|e| Error::tls(format!("failed to build client verifier: {e}")))?; - ServerConfig::builder() - .with_client_cert_verifier(verifier) - .with_single_cert(certs, key) - .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))? - } else { - ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))? - }; + info!(?dirs, "TLS certificate file watcher started"); - config - .alpn_protocols - .extend([b"h2".to_vec(), b"http/1.1".to_vec()]); + // Event loop with manual debounce. + // When the watcher fires, we drain any follow-up events that + // arrive within the debounce window before calling reload(). + // This handles Kubernetes Secret atomic-swap patterns where + // kubelet writes a new ..data directory and swaps symlinks in + // rapid succession. - Ok(Self { - acceptor: tokio_rustls::TlsAcceptor::from(Arc::new(config)), + 'outer: loop { + // Wait for the first event (or shutdown). + let got_event = tokio::select! { + r = rx.recv() => r.is_some(), + _ = shutdown.changed() => break 'outer, + }; + + if !got_event { + warn!("TLS cert file watcher disconnected, hot-reload stopping"); + break 'outer; + } + + // Debounce: keep draining events for the debounce duration. + loop { + tokio::select! { + () = tokio::time::sleep(debounce) => { + // Debounce window elapsed — reload now. + if let Err(e) = this.reload() { + let event = ConfigStateChangeBuilder::new(&tls_ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Enabled, "reload_failed") + .message(format!( + "TLS certificate reload failed: {e}" + )) + .build(); + info!( + target: OCSF_TARGET, + sandbox_id = "", + message = %event.format_shorthand() + ); + warn!(error = %e, "TLS certificate reload failed, keeping existing config"); + } + break; + } + r = rx.recv() => { + if r.is_some() { + // Another event arrived — reset debounce. + continue; + } + warn!("TLS cert file watcher disconnected, hot-reload stopping"); + break 'outer; + } + _ = shutdown.changed() => { + debug!("TLS certificate reload worker stopped"); + break 'outer; + } + } + } + } + + // `watcher` dropped here — stops the underlying watcher thread. }) } +} - /// Get the inner tokio-rustls acceptor. - #[must_use] - #[allow(clippy::missing_const_for_fn)] - pub fn inner(&self) -> &tokio_rustls::TlsAcceptor { - &self.acceptor - } +/// Build a `ServerConfig` from certificate, key, and optional client CA files. +fn build_server_config( + cert_path: &Path, + key_path: &Path, + client_ca_path: Option<&Path>, + require_client_auth: bool, +) -> Result> { + let certs = load_certs(cert_path)?; + let key = load_key(key_path)?; + + // Validate the key type early — rustls defers this to handshake time, + // which produces a cryptic error. A bad key type surfaces clearly here. + sign::any_supported_type(&key) + .map_err(|e| Error::tls(format!("unsupported private key type: {e}")))?; + + let mut config = if let Some(ca_path) = client_ca_path { + let ca_certs = load_certs(ca_path)?; + let mut root_store = rustls::RootCertStore::empty(); + for cert in ca_certs { + root_store + .add(cert) + .map_err(|e| Error::tls(format!("failed to add CA certificate: {e}")))?; + } + + let verifier_builder = WebPkiClientVerifier::builder(Arc::new(root_store)); + let verifier = if require_client_auth { + verifier_builder + } else { + verifier_builder.allow_unauthenticated() + } + .build() + .map_err(|e| Error::tls(format!("failed to build client verifier: {e}")))?; + + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(certs, key) + .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::tls(format!("failed to create TLS config: {e}")))? + }; + + config + .alpn_protocols + .extend([b"h2".to_vec(), b"http/1.1".to_vec()]); + + Ok(Arc::new(config)) } /// Load certificates from a PEM file. @@ -128,3 +332,621 @@ fn load_key(path: &Path) -> Result> { Err(Error::tls("no private key found in file")) } + +/// Build an OCSF context for gateway-level (non-sandbox) events. +fn tls_ocsf_ctx() -> SandboxContext { + SandboxContext { + sandbox_id: String::new(), + sandbox_name: String::new(), + container_image: "openshell/gateway".to_string(), + hostname: "openshell-gateway".to_string(), + product_version: openshell_core::VERSION.to_string(), + proxy_ip: std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), + proxy_port: 0, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rcgen::{CertificateParams, IsCa, KeyPair, KeyUsagePurpose}; + use std::io::Write; + use tokio::net::{TcpListener, TcpStream}; + + fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); + } + + /// Generate test CA + server certs in `dir`, returning the CA cert and + /// keypair so callers can sign additional server or client certificates. + fn generate_test_certs_with_ca(dir: &Path) -> (rcgen::Certificate, KeyPair) { + let mut ca_params = + CertificateParams::new(Vec::::new()).expect("failed to create CA params"); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let ca_key = KeyPair::generate().expect("failed to generate CA key"); + let ca_cert = ca_params + .self_signed(&ca_key) + .expect("failed to sign CA cert"); + + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &ca_cert, &ca_key) + .expect("failed to sign server cert"); + + let write_file = |name: &str, data: &[u8]| { + let path = dir.join(name); + File::create(&path) + .and_then(|mut file| file.write_all(data)) + .expect("failed to write test file"); + }; + write_file("ca.pem", ca_cert.pem().as_bytes()); + write_file("server-cert.pem", server_cert.pem().as_bytes()); + write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + + (ca_cert, ca_key) + } + + /// Generate a new server cert + key in `dir`, signed by the given CA. + /// Overwrites `server-cert.pem` and `server-key.pem`. + fn generate_server_cert(ca_cert: &rcgen::Certificate, ca_key: &KeyPair, dir: &Path) { + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, ca_cert, ca_key) + .expect("failed to sign server cert"); + + let write_file = |name: &str, data: &[u8]| { + let path = dir.join(name); + File::create(&path) + .and_then(|mut file| file.write_all(data)) + .expect("failed to write test file"); + }; + write_file("server-cert.pem", server_cert.pem().as_bytes()); + write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + } + + fn build_test_client_config(ca_path: &Path) -> Arc { + let ca_certs = load_certs(ca_path).expect("failed to load CA certs"); + let mut root_store = rustls::RootCertStore::empty(); + for cert in ca_certs { + root_store + .add(cert) + .expect("failed to add CA to root store"); + } + Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(), + ) + } + + #[test] + fn test_build_server_config() { + install_rustls_provider(); + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let _ = generate_test_certs_with_ca(dir.path()); + + let config = build_server_config( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build server config"); + + assert!(config.alpn_protocols.contains(&b"h2".to_vec())); + assert!(config.alpn_protocols.contains(&b"http/1.1".to_vec())); + } + + #[test] + fn test_reload_success() { + install_rustls_provider(); + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let (ca_cert, ca_key) = generate_test_certs_with_ca(dir.path()); + + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + generate_server_cert(&ca_cert, &ca_key, dir.path()); + acceptor.reload().expect("reload should succeed"); + } + + #[test] + fn test_reload_invalid_preserves_old() { + install_rustls_provider(); + let dir = tempfile::tempdir().expect("failed to create tempdir"); + generate_test_certs_with_ca(dir.path()); + + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + // Snapshot the current config before corrupting files. + let acceptor_before = acceptor.acceptor(); + + std::fs::write(dir.path().join("server-cert.pem"), b"garbage") + .expect("failed to write garbage"); + + assert!(acceptor.reload().is_err(), "reload with garbage cert should fail"); + + // Old config must still be accessible after a failed reload. + drop(acceptor_before); + let _ = acceptor.acceptor(); // does not panic + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_concurrent_handshake_and_reload() { + install_rustls_provider(); + + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let (ca_cert, ca_key) = generate_test_certs_with_ca(dir.path()); + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + let client_config = build_test_client_config(&dir.path().join("ca.pem")); + let connector = tokio_rustls::TlsConnector::from(client_config); + let server_name = + rustls::pki_types::ServerName::try_from("localhost").expect("invalid server name"); + + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let listen_addr = listener.local_addr().expect("failed to get local addr"); + + let acceptor_for_server = acceptor.clone(); + let (server_done_tx, mut server_done_rx) = watch::channel(false); + let server_handle = tokio::spawn(async move { + loop { + tokio::select! { + result = listener.accept() => { + match result { + Ok((stream, _)) => { + let acc = acceptor_for_server.clone(); + tokio::spawn(async move { + let _ = acc.acceptor().accept(stream).await; + }); + } + Err(_) => break, + } + } + _ = server_done_rx.changed() => break, + } + } + }); + + let num_clients = 10; + let cycles_per_client = 5; + let mut client_handles = Vec::with_capacity(num_clients); + + for i in 0..num_clients { + let connector = connector.clone(); + let name = server_name.clone(); + let addr = listen_addr; + client_handles.push(tokio::spawn(async move { + for _ in 0..cycles_per_client { + let tcp = TcpStream::connect(addr) + .await + .expect("client connect failed"); + let _tls = connector + .connect(name.clone(), tcp) + .await + .expect("client TLS handshake failed"); + } + i + })); + } + + let reload_handle = tokio::spawn(async move { + for _ in 0..20 { + generate_server_cert(&ca_cert, &ca_key, dir.path()); + acceptor.reload().expect("reload with valid cert should succeed"); + tokio::time::sleep(Duration::from_millis(5)).await; + } + }); + + let mut client_failures = 0; + for handle in client_handles { + match handle.await { + Err(e) if e.is_panic() => client_failures += 1, + _ => {} + } + } + assert_eq!(client_failures, 0, "some client tasks panicked"); + + reload_handle.await.expect("reload task panicked"); + + let _ = server_done_tx.send(true); + let _ = tokio::time::timeout(Duration::from_secs(2), server_handle).await; + } + + #[tokio::test] + async fn test_reload_serves_new_cert() { + // Helper to extract the DER-encoded server certificate from a TLS session. + fn peer_cert_der(tls: &tokio_rustls::client::TlsStream) -> Vec { + tls.get_ref().1.peer_certificates().expect("no peer certs")[0] + .as_ref() + .to_vec() + } + + install_rustls_provider(); + + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let (ca_cert, ca_key) = generate_test_certs_with_ca(dir.path()); + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + let client_config = build_test_client_config(&dir.path().join("ca.pem")); + let connector = tokio_rustls::TlsConnector::from(client_config); + let server_name = + rustls::pki_types::ServerName::try_from("localhost").expect("invalid server name"); + + // Connection 1: original cert + let listener1 = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let addr1 = listener1.local_addr().expect("failed to get local addr"); + let acc1 = acceptor.clone(); + let server_task_1 = tokio::spawn(async move { + let (stream, _) = listener1.accept().await.expect("accept failed"); + acc1.acceptor() + .accept(stream) + .await + .expect("TLS accept failed"); + }); + let tcp_1 = TcpStream::connect(addr1) + .await + .expect("connect failed"); + let tls_1 = connector + .connect(server_name.clone(), tcp_1) + .await + .expect("TLS handshake failed"); + let server_cert_1 = peer_cert_der(&tls_1); + server_task_1.await.expect("server task 1 failed"); + assert!(!server_cert_1.is_empty()); + + // Generate new server cert + reload + generate_server_cert(&ca_cert, &ca_key, dir.path()); + acceptor.reload().expect("reload should succeed"); + + // Connection 2: new cert on a fresh listener + let listener2 = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let addr2 = listener2.local_addr().expect("failed to get local addr"); + let acc2 = acceptor.clone(); + let server_task_2 = tokio::spawn(async move { + let (stream, _) = listener2.accept().await.expect("accept failed"); + acc2.acceptor() + .accept(stream) + .await + .expect("TLS accept failed"); + }); + let tcp_2 = TcpStream::connect(addr2) + .await + .expect("connect failed"); + let tls_2 = connector + .connect(server_name.clone(), tcp_2) + .await + .expect("TLS handshake failed"); + let server_cert_2 = peer_cert_der(&tls_2); + server_task_2.await.expect("server task 2 failed"); + assert!(!server_cert_2.is_empty()); + + assert_ne!( + server_cert_1, server_cert_2, + "served cert should change after reload" + ); + } + + #[tokio::test] + async fn test_reload_worker_shutdown() { + install_rustls_provider(); + + let dir = tempfile::tempdir().expect("failed to create tempdir"); + generate_test_certs_with_ca(dir.path()); + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = acceptor.spawn_reload_worker(shutdown_rx); + + // Give the watcher time to start + tokio::time::sleep(Duration::from_millis(200)).await; + + // Send shutdown signal and verify the worker exits promptly. + // shutdown.changed() is checked in both the outer select (waiting + // for first event) and the inner debounce loop — no file change is + // needed to exercise the shutdown path. + let _ = shutdown_tx.send(true); + tokio::time::timeout(Duration::from_secs(2), handle) + .await + .expect("reload worker should exit after shutdown signal") + .expect("reload worker should not panic"); + } + + #[tokio::test] + async fn test_reload_worker_detects_file_change() { + install_rustls_provider(); + + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let (ca_cert, ca_key) = generate_test_certs_with_ca(dir.path()); + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + false, + ) + .expect("failed to build acceptor"); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = acceptor.spawn_reload_worker(shutdown_rx); + + // Give the watcher time to start + tokio::time::sleep(Duration::from_millis(200)).await; + + // Generate new server cert — the file watcher should detect this + generate_server_cert(&ca_cert, &ca_key, dir.path()); + + // Retry TLS handshake until the watcher picks up the new cert or + // the deadline expires. Follows the wait_for_status pattern + // (health_endpoint_integration.rs). + let client_config = build_test_client_config(&dir.path().join("ca.pem")); + let connector = tokio_rustls::TlsConnector::from(client_config); + let server_name = + rustls::pki_types::ServerName::try_from("localhost").expect("invalid server name"); + + let deadline = tokio::time::Instant::now() + Duration::from_secs(10); + loop { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let listen_addr = listener.local_addr().expect("failed to get local addr"); + + let acc = acceptor.clone(); + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept failed"); + acc.acceptor().accept(stream).await + }); + + let tcp = TcpStream::connect(listen_addr) + .await + .expect("connect failed"); + let handshake_result = tokio::time::timeout( + Duration::from_millis(500), + connector.connect(server_name.clone(), tcp), + ) + .await; + + if let Ok(Ok(_tls)) = handshake_result { + server_task + .await + .expect("server task failed") + .expect("TLS accept should succeed after watcher reload"); + break; + } + + assert!( + tokio::time::Instant::now() < deadline, + "watcher did not detect cert change within 10s" + ); + tokio::time::sleep(Duration::from_millis(200)).await; + } + + // Clean shutdown + let _ = shutdown_tx.send(true); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + } + + #[tokio::test] + async fn test_reload_mtls_ca_rotation() { + install_rustls_provider(); + + let dir = tempfile::tempdir().expect("failed to create tempdir"); + let (initial_ca_cert, initial_ca_key) = generate_test_certs_with_ca(dir.path()); + + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + Some(&dir.path().join("ca.pem")), + true, // require mTLS + ) + .expect("failed to build acceptor with mTLS"); + + // Generate new CA and overwrite ca.pem + let mut new_ca_params = + CertificateParams::new(Vec::::new()).expect("failed to create CA params"); + new_ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + new_ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "new-ca"); + let new_ca_key = KeyPair::generate().expect("failed to generate new CA key"); + let new_ca_cert = new_ca_params + .self_signed(&new_ca_key) + .expect("failed to sign new CA cert"); + std::fs::write( + dir.path().join("ca.pem"), + new_ca_cert.pem().as_bytes(), + ) + .expect("failed to write new CA"); + + // Generate new server cert signed by new CA + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &new_ca_cert, &new_ca_key) + .expect("failed to sign server cert"); + std::fs::write( + dir.path().join("server-cert.pem"), + server_cert.pem().as_bytes(), + ) + .expect("failed to write server cert"); + std::fs::write( + dir.path().join("server-key.pem"), + server_key.serialize_pem().as_bytes(), + ) + .expect("failed to write server key"); + + acceptor.reload().expect("reload with new CA should succeed"); + + // Generate client cert signed by new CA, write to files + let client_key = KeyPair::generate().expect("failed to generate client key"); + let mut client_params = CertificateParams::new(Vec::::new()) + .expect("failed to create client params"); + client_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-client"); + client_params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + let client_cert = client_params + .signed_by(&client_key, &new_ca_cert, &new_ca_key) + .expect("failed to sign client cert"); + + // Write client cert + key as PEM files and load via load_certs/load_key + let client_cert_path = dir.path().join("client-cert.pem"); + let client_key_path = dir.path().join("client-key.pem"); + std::fs::write(&client_cert_path, client_cert.pem().as_bytes()) + .expect("failed to write client cert"); + std::fs::write(&client_key_path, client_key.serialize_pem().as_bytes()) + .expect("failed to write client key"); + + let client_cert_chain = load_certs(&client_cert_path).expect("failed to load client cert"); + let client_private_key = load_key(&client_key_path).expect("failed to load client key"); + + let mut root_store = rustls::RootCertStore::empty(); + root_store + .add(CertificateDer::from(new_ca_cert.der().to_vec())) + .expect("failed to add new CA to root store"); + let new_ca_client_config = Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(client_cert_chain, client_private_key) + .expect("failed to set client auth cert"), + ); + + // Verify a real handshake succeeds with the new CA + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let listen_addr = listener.local_addr().expect("failed to get local addr"); + + let acceptor_srv = acceptor.clone(); + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept failed"); + acceptor_srv + .acceptor() + .accept(stream) + .await + .expect("TLS accept with new mTLS CA failed"); + }); + + let connector = tokio_rustls::TlsConnector::from(new_ca_client_config); + let server_name = + rustls::pki_types::ServerName::try_from("localhost").expect("invalid server name"); + let tcp = TcpStream::connect(listen_addr) + .await + .expect("connect failed"); + connector + .connect(server_name.clone(), tcp) + .await + .expect("mTLS handshake with new CA should succeed"); + + server_task.await.expect("server task failed"); + + // Verify old CA is no longer trusted: a client cert signed by the + // initial CA should be rejected after rotation. + let old_client_key = KeyPair::generate().expect("failed to generate old client key"); + let mut old_client_params = CertificateParams::new(Vec::::new()) + .expect("failed to create old client params"); + old_client_params + .distinguished_name + .push(rcgen::DnType::CommonName, "old-client"); + old_client_params.key_usages = vec![ + KeyUsagePurpose::DigitalSignature, + KeyUsagePurpose::KeyEncipherment, + ]; + let old_client_cert = old_client_params + .signed_by(&old_client_key, &initial_ca_cert, &initial_ca_key) + .expect("failed to sign old client cert"); + + let old_cert_path = dir.path().join("old-client-cert.pem"); + let old_key_path = dir.path().join("old-client-key.pem"); + std::fs::write(&old_cert_path, old_client_cert.pem().as_bytes()) + .expect("failed to write old client cert"); + std::fs::write(&old_key_path, old_client_key.serialize_pem().as_bytes()) + .expect("failed to write old client key"); + + let old_cert_chain = load_certs(&old_cert_path).expect("failed to load old client cert"); + let old_key_der = load_key(&old_key_path).expect("failed to load old client key"); + + let mut old_root_store = rustls::RootCertStore::empty(); + old_root_store + .add(CertificateDer::from(new_ca_cert.der().to_vec())) + .expect("failed to add new CA to root store"); + let old_ca_client_config = Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(old_root_store) + .with_client_auth_cert(old_cert_chain, old_key_der) + .expect("failed to set old client auth cert"), + ); + + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let listen_addr = listener.local_addr().expect("failed to get local addr"); + + let acceptor_srv = acceptor.clone(); + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept failed"); + let result = acceptor_srv.acceptor().accept(stream).await; + assert!( + result.is_err(), + "TLS accept should reject client cert signed by old CA" + ); + }); + + // Drive the client side so the server has something to accept. + // tokio_rustls::TlsConnector::connect() may return Ok even when the + // server rejects the client cert (TLS 1.3 post-handshake auth), + // so the authoritative check is the server-side accept result above. + let connector = tokio_rustls::TlsConnector::from(old_ca_client_config); + let tcp = TcpStream::connect(listen_addr) + .await + .expect("connect failed"); + let _ = connector.connect(server_name, tcp).await; + + server_task.await.expect("server task failed"); + } +} diff --git a/crates/openshell-server/tests/common/mod.rs b/crates/openshell-server/tests/common/mod.rs index 00228b043..3077cf4c9 100644 --- a/crates/openshell-server/tests/common/mod.rs +++ b/crates/openshell-server/tests/common/mod.rs @@ -566,7 +566,7 @@ pub async fn start_test_server( let svc = service.clone(); let tls = tls_acceptor.clone(); tokio::spawn(async move { - let Ok(tls_stream) = tls.inner().accept(stream).await else { + let Ok(tls_stream) = tls.acceptor().accept(stream).await else { return; }; let _ = Builder::new(TokioExecutor::new()) From 7d17a9819f52855a8e9c2c8e27b062ca7a60eb45 Mon Sep 17 00:00:00 2001 From: Yuedong Wu Date: Thu, 11 Jun 2026 19:29:51 +0800 Subject: [PATCH 2/2] refactor(server): extract shared TLS test utilities Signed-off-by: Yuedong Wu --- crates/openshell-server/src/lib.rs | 37 +------ crates/openshell-server/src/tls.rs | 104 ++++-------------- crates/openshell-server/src/tls_test_utils.rs | 62 +++++++++++ 3 files changed, 90 insertions(+), 113 deletions(-) create mode 100644 crates/openshell-server/src/tls_test_utils.rs diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index b7304f899..c13a0d462 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -40,6 +40,8 @@ mod ssh_sessions; pub mod supervisor_session; mod telemetry; mod tls; +#[cfg(test)] +pub(crate) mod tls_test_utils; pub mod tracing_bus; mod ws_tunnel; @@ -919,8 +921,7 @@ mod tests { ComputeDriverKind, Config, proto::{HealthRequest, open_shell_client::OpenShellClient}, }; - use rcgen::{CertificateParams, IsCa, KeyPair}; - use std::io::{Error, ErrorKind, Write}; + use std::io::{Error, ErrorKind}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -929,41 +930,13 @@ mod tests { use tokio::net::{TcpListener, TcpStream}; use tokio::sync::watch; - fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); - } + use crate::tls_test_utils::{generate_test_certs_with_ca, install_rustls_provider}; fn test_tls_acceptor() -> (TempDir, TlsAcceptor) { install_rustls_provider(); - let mut ca_params = - CertificateParams::new(Vec::::new()).expect("failed to create CA params"); - ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - ca_params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-ca"); - let ca_key = KeyPair::generate().expect("failed to generate CA key"); - let ca_cert = ca_params - .self_signed(&ca_key) - .expect("failed to sign CA cert"); - - let server_params = CertificateParams::new(vec!["localhost".to_string()]) - .expect("failed to create server params"); - let server_key = KeyPair::generate().expect("failed to generate server key"); - let server_cert = server_params - .signed_by(&server_key, &ca_cert, &ca_key) - .expect("failed to sign server cert"); - let dir = tempdir().expect("failed to create tempdir"); - let write_file = |name: &str, data: &[u8]| { - let path = dir.path().join(name); - std::fs::File::create(&path) - .and_then(|mut file| file.write_all(data)) - .expect("failed to write tls test file"); - }; - write_file("ca.pem", ca_cert.pem().as_bytes()); - write_file("server-cert.pem", server_cert.pem().as_bytes()); - write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + generate_test_certs_with_ca(dir.path()); let acceptor = TlsAcceptor::from_files( &dir.path().join("server-cert.pem"), diff --git a/crates/openshell-server/src/tls.rs b/crates/openshell-server/src/tls.rs index 8da7f4936..c8eb0e5c9 100644 --- a/crates/openshell-server/src/tls.rs +++ b/crates/openshell-server/src/tls.rs @@ -349,48 +349,12 @@ fn tls_ocsf_ctx() -> SandboxContext { #[cfg(test)] mod tests { use super::*; + use crate::tls_test_utils::{ + generate_test_certs_with_ca, install_rustls_provider, write_test_file, + }; use rcgen::{CertificateParams, IsCa, KeyPair, KeyUsagePurpose}; - use std::io::Write; use tokio::net::{TcpListener, TcpStream}; - fn install_rustls_provider() { - let _ = rustls::crypto::ring::default_provider().install_default(); - } - - /// Generate test CA + server certs in `dir`, returning the CA cert and - /// keypair so callers can sign additional server or client certificates. - fn generate_test_certs_with_ca(dir: &Path) -> (rcgen::Certificate, KeyPair) { - let mut ca_params = - CertificateParams::new(Vec::::new()).expect("failed to create CA params"); - ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); - ca_params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-ca"); - let ca_key = KeyPair::generate().expect("failed to generate CA key"); - let ca_cert = ca_params - .self_signed(&ca_key) - .expect("failed to sign CA cert"); - - let server_params = CertificateParams::new(vec!["localhost".to_string()]) - .expect("failed to create server params"); - let server_key = KeyPair::generate().expect("failed to generate server key"); - let server_cert = server_params - .signed_by(&server_key, &ca_cert, &ca_key) - .expect("failed to sign server cert"); - - let write_file = |name: &str, data: &[u8]| { - let path = dir.join(name); - File::create(&path) - .and_then(|mut file| file.write_all(data)) - .expect("failed to write test file"); - }; - write_file("ca.pem", ca_cert.pem().as_bytes()); - write_file("server-cert.pem", server_cert.pem().as_bytes()); - write_file("server-key.pem", server_key.serialize_pem().as_bytes()); - - (ca_cert, ca_key) - } - /// Generate a new server cert + key in `dir`, signed by the given CA. /// Overwrites `server-cert.pem` and `server-key.pem`. fn generate_server_cert(ca_cert: &rcgen::Certificate, ca_key: &KeyPair, dir: &Path) { @@ -401,14 +365,8 @@ mod tests { .signed_by(&server_key, ca_cert, ca_key) .expect("failed to sign server cert"); - let write_file = |name: &str, data: &[u8]| { - let path = dir.join(name); - File::create(&path) - .and_then(|mut file| file.write_all(data)) - .expect("failed to write test file"); - }; - write_file("server-cert.pem", server_cert.pem().as_bytes()); - write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + write_test_file(dir, "server-cert.pem", server_cert.pem().as_bytes()); + write_test_file(dir, "server-key.pem", server_key.serialize_pem().as_bytes()); } fn build_test_client_config(ca_path: &Path) -> Arc { @@ -482,7 +440,10 @@ mod tests { std::fs::write(dir.path().join("server-cert.pem"), b"garbage") .expect("failed to write garbage"); - assert!(acceptor.reload().is_err(), "reload with garbage cert should fail"); + assert!( + acceptor.reload().is_err(), + "reload with garbage cert should fail" + ); // Old config must still be accessible after a failed reload. drop(acceptor_before); @@ -559,7 +520,9 @@ mod tests { let reload_handle = tokio::spawn(async move { for _ in 0..20 { generate_server_cert(&ca_cert, &ca_key, dir.path()); - acceptor.reload().expect("reload with valid cert should succeed"); + acceptor + .reload() + .expect("reload with valid cert should succeed"); tokio::time::sleep(Duration::from_millis(5)).await; } }); @@ -618,9 +581,7 @@ mod tests { .await .expect("TLS accept failed"); }); - let tcp_1 = TcpStream::connect(addr1) - .await - .expect("connect failed"); + let tcp_1 = TcpStream::connect(addr1).await.expect("connect failed"); let tls_1 = connector .connect(server_name.clone(), tcp_1) .await @@ -646,9 +607,7 @@ mod tests { .await .expect("TLS accept failed"); }); - let tcp_2 = TcpStream::connect(addr2) - .await - .expect("connect failed"); + let tcp_2 = TcpStream::connect(addr2).await.expect("connect failed"); let tls_2 = connector .connect(server_name.clone(), tcp_2) .await @@ -793,36 +752,19 @@ mod tests { let new_ca_cert = new_ca_params .self_signed(&new_ca_key) .expect("failed to sign new CA cert"); - std::fs::write( - dir.path().join("ca.pem"), - new_ca_cert.pem().as_bytes(), - ) - .expect("failed to write new CA"); - - // Generate new server cert signed by new CA - let server_params = CertificateParams::new(vec!["localhost".to_string()]) - .expect("failed to create server params"); - let server_key = KeyPair::generate().expect("failed to generate server key"); - let server_cert = server_params - .signed_by(&server_key, &new_ca_cert, &new_ca_key) - .expect("failed to sign server cert"); - std::fs::write( - dir.path().join("server-cert.pem"), - server_cert.pem().as_bytes(), - ) - .expect("failed to write server cert"); - std::fs::write( - dir.path().join("server-key.pem"), - server_key.serialize_pem().as_bytes(), - ) - .expect("failed to write server key"); + std::fs::write(dir.path().join("ca.pem"), new_ca_cert.pem().as_bytes()) + .expect("failed to write new CA"); - acceptor.reload().expect("reload with new CA should succeed"); + // Generate new server cert signed by new CA and reload + generate_server_cert(&new_ca_cert, &new_ca_key, dir.path()); + acceptor + .reload() + .expect("reload with new CA should succeed"); // Generate client cert signed by new CA, write to files let client_key = KeyPair::generate().expect("failed to generate client key"); - let mut client_params = CertificateParams::new(Vec::::new()) - .expect("failed to create client params"); + let mut client_params = + CertificateParams::new(Vec::::new()).expect("failed to create client params"); client_params .distinguished_name .push(rcgen::DnType::CommonName, "test-client"); diff --git a/crates/openshell-server/src/tls_test_utils.rs b/crates/openshell-server/src/tls_test_utils.rs new file mode 100644 index 000000000..aee83c49e --- /dev/null +++ b/crates/openshell-server/src/tls_test_utils.rs @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared test helpers for TLS-related tests. + +use std::fs::File; +use std::io::Write; +use std::path::Path; + +use rcgen::{CertificateParams, IsCa, KeyPair}; + +/// Install the default rustls crypto provider. +/// +/// Must be called once at the start of any test that exercises TLS handshakes. +/// Multiple calls are harmless (subsequent calls return an error, ignored). +pub fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +/// Write bytes to a file inside `dir`, panicking on failure. +pub fn write_test_file(dir: &Path, name: &str, data: &[u8]) { + let path = dir.join(name); + File::create(&path) + .and_then(|mut file| file.write_all(data)) + .expect("failed to write test file"); +} + +/// Generate a self-signed CA certificate and a `localhost` server certificate, +/// writing them as PEM files into `dir`. +/// +/// Returns the CA certificate and keypair so callers can sign additional +/// server or client certificates. +/// +/// Files written: +/// - `ca.pem` +/// - `server-cert.pem` +/// - `server-key.pem` +pub fn generate_test_certs_with_ca(dir: &Path) -> (rcgen::Certificate, KeyPair) { + let mut ca_params = + CertificateParams::new(Vec::::new()).expect("failed to create CA params"); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let ca_key = KeyPair::generate().expect("failed to generate CA key"); + let ca_cert = ca_params + .self_signed(&ca_key) + .expect("failed to sign CA cert"); + + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &ca_cert, &ca_key) + .expect("failed to sign server cert"); + + write_test_file(dir, "ca.pem", ca_cert.pem().as_bytes()); + write_test_file(dir, "server-cert.pem", server_cert.pem().as_bytes()); + write_test_file(dir, "server-key.pem", server_key.serialize_pem().as_bytes()); + + (ca_cert, ca_key) +}