diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index ec0efded6..10ef69838 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -92,7 +92,10 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. +runtime still owns GPU device injection. GPU requests are explicit, and can be +refined with a driver-native device identifier or requested count; the gateway +validates the request shape and each runtime enforces the GPU allocation modes it +supports. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index ea0dd79ca..01b57ec97 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -19,6 +19,7 @@ use openshell_bootstrap::{ use openshell_cli::completers; use openshell_cli::run; use openshell_cli::tls::TlsOptions; +use openshell_core::proto::GpuResourceRequirements; /// Resolved gateway context: name + gateway endpoint. struct GatewayContext { @@ -28,6 +29,12 @@ struct GatewayContext { endpoint: String, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum GpuCliRequest { + DriverDefault, + Count(u32), +} + /// Resolve the gateway name to a [`GatewayContext`] with the gateway endpoint. 
/// /// Resolution priority: @@ -109,6 +116,29 @@ fn resolve_gateway( }) } +fn resolve_gpu_requirements(gpu: Option) -> Option { + match gpu { + Some(GpuCliRequest::Count(count)) => Some(GpuResourceRequirements { count: Some(count) }), + Some(GpuCliRequest::DriverDefault) => Some(GpuResourceRequirements { count: None }), + None => None, + } +} + +fn parse_gpu_request(value: &str) -> std::result::Result { + if value.is_empty() { + return Ok(GpuCliRequest::DriverDefault); + } + + let count = value + .parse::() + .map_err(|_| "GPU count must be a positive integer".to_string())?; + if count == 0 { + return Err("GPU count must be greater than 0".to_string()); + } + + Ok(GpuCliRequest::Count(count)) +} + fn resolve_gateway_name(gateway_flag: &Option) -> Option { gateway_flag .clone() @@ -1212,8 +1242,11 @@ enum SandboxCommands { editor: Option, /// Request GPU resources for the sandbox. - #[arg(long)] - gpu: bool, + /// + /// Omit COUNT for the driver's default GPU selection, or pass COUNT + /// to request a specific number of GPUs. + #[arg(long, num_args = 0..=1, value_name = "COUNT", default_missing_value = "", value_parser = parse_gpu_request)] + gpu: Option, /// CPU limit for the sandbox (for example: 500m, 1, 2.5). 
#[arg(long)] @@ -2621,6 +2654,7 @@ async fn main() -> Result<()> { .map(|s| openshell_core::forward::ForwardSpec::parse(&s)) .transpose()?; let keep = keep || !no_keep || editor.is_some() || forward.is_some(); + let gpu_requirements = resolve_gpu_requirements(gpu); let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; let endpoint = &ctx.endpoint; @@ -2633,7 +2667,7 @@ async fn main() -> Result<()> { &ctx.name, &upload_specs, keep, - gpu, + gpu_requirements, cpu.as_deref(), memory.as_deref(), driver_config_json.as_deref(), @@ -3628,6 +3662,29 @@ mod tests { }); } + #[test] + fn resolve_gpu_requirements_handles_absent_gpu() { + let gpu = resolve_gpu_requirements(None); + + assert_eq!(gpu, None); + } + + #[test] + fn resolve_gpu_requirements_handles_driver_default() { + let gpu = resolve_gpu_requirements(Some(GpuCliRequest::DriverDefault)) + .expect("GPU requirement should be present"); + + assert_eq!(gpu.count, None); + } + + #[test] + fn resolve_gpu_requirements_handles_gpu_count() { + let gpu = resolve_gpu_requirements(Some(GpuCliRequest::Count(2))) + .expect("GPU requirement should be present"); + + assert_eq!(gpu.count, Some(2)); + } + #[test] fn apply_auth_uses_stored_token() { let tmp = tempfile::tempdir().unwrap(); @@ -4443,6 +4500,113 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_parses_driver_default() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu"]) + .expect("sandbox create --gpu should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. 
+ }) => { + assert_eq!(gpu, Some(GpuCliRequest::DriverDefault)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_parses_from_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "2"]) + .expect("sandbox create --gpu 2 should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_driver_default_allows_trailing_command() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "--", "claude"]) + .expect("sandbox create --gpu -- claude should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, command, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::DriverDefault)); + assert_eq!(command, vec!["claude".to_string()]); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_allows_trailing_command() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "2", + "--", + "claude", + ]) + .expect("sandbox create --gpu 2 -- claude should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, command, .. }), + .. 
+ }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + assert_eq!(command, vec!["claude".to_string()]); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "0"]); + + assert!(result.is_err(), "sandbox create --gpu 0 should be rejected"); + } + + #[test] + fn sandbox_create_gpu_count_accepts_equals_syntax() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu=2"]) + .expect("sandbox create --gpu=2 should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, .. }), + .. + }) => { + assert_eq!(gpu, Some(GpuCliRequest::Count(2))); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_non_integer() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "many"]); + + assert!( + result.is_err(), + "sandbox create --gpu many should be rejected" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index dbd240238..d242434fb 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,12 +39,13 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, - ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, - ListSandboxesRequest, ListServicesRequest, 
PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirements, + HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, + ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, + ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, + PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, ResourceRequirements, RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, @@ -120,7 +121,7 @@ fn ready_false_condition_message( fn provisioning_timeout_message( timeout_secs: u64, - requested_gpu: bool, + resource_requirements: Option<&ResourceRequirements>, condition_message: Option<&str>, ) -> String { let mut message = format!("sandbox provisioning timed out after {timeout_secs}s"); @@ -130,7 +131,7 @@ fn provisioning_timeout_message( message.push_str(condition_message); } - if requested_gpu { + if resource_requirements.is_some_and(|requirements| requirements.gpu.is_some()) { message.push_str( ". 
Hint: this may be because the available GPU is already in use by another sandbox.", ); @@ -1724,7 +1725,7 @@ pub async fn sandbox_create( gateway_name: &str, uploads: &[(String, Option, bool)], keep: bool, - gpu: bool, + gpu_requirements: Option, cpu: Option<&str>, memory: Option<&str>, driver_config_json: Option<&str>, @@ -1780,8 +1781,6 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu; - let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { Vec::new() @@ -1813,9 +1812,11 @@ pub async fn sandbox_create( None }; + let resource_requirements = gpu_requirements.map(|gpu| ResourceRequirements { gpu: Some(gpu) }); + let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, + resource_requirements, environment: environment.clone(), policy, providers: configured_providers, @@ -1960,7 +1961,7 @@ pub async fn sandbox_create( if remaining.is_zero() { let timeout_message = provisioning_timeout_message( provision_timeout.as_secs(), - requested_gpu, + resource_requirements.as_ref(), last_condition_message.as_deref(), ); if let Some(d) = display.as_mut() { @@ -1979,7 +1980,7 @@ pub async fn sandbox_create( // Timeout fired — the stream was idle for too long. 
let timeout_message = provisioning_timeout_message( provision_timeout.as_secs(), - requested_gpu, + resource_requirements.as_ref(), last_condition_message.as_deref(), ); if let Some(d) = display.as_mut() { @@ -7595,9 +7596,10 @@ mod tests { PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::{ - Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus, - ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrant, ProviderProfile, - ProviderProfileCredential, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, + GpuResourceRequirements, Provider, ProviderCredentialRefresh, + ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, + ProviderCredentialTokenGrant, ProviderProfile, ProviderProfileCredential, + ResourceRequirements, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, }; struct EnvVarGuard { @@ -8285,9 +8287,12 @@ mod tests { #[test] fn provisioning_timeout_message_includes_condition_and_gpu_hint() { + let resource_requirements = ResourceRequirements { + gpu: Some(GpuResourceRequirements { count: None }), + }; let message = provisioning_timeout_message( 120, - true, + Some(&resource_requirements), Some("DependenciesNotReady: Pod exists with phase: Pending; Service Exists"), ); @@ -8298,7 +8303,15 @@ mod tests { #[test] fn provisioning_timeout_message_omits_gpu_hint_for_non_gpu_requests() { - let message = provisioning_timeout_message(120, false, None); + let message = provisioning_timeout_message(120, None, None); + + assert_eq!(message, "sandbox provisioning timed out after 120s"); + } + + #[test] + fn provisioning_timeout_message_omits_gpu_hint_without_gpu_requirements() { + let resource_requirements = ResourceRequirements { gpu: None }; + let message = provisioning_timeout_message(120, Some(&resource_requirements), None); assert_eq!(message, "sandbox provisioning timed out after 120s"); } diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs 
b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 7061614cb..4d3614d2c 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -18,13 +18,14 @@ use openshell_core::proto::{ ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, - GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, - ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, - ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, PlatformEvent, - ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, SandboxCondition, - SandboxLogLine, SandboxPhase, SandboxResponse, SandboxStatus, SandboxStreamEvent, - ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, - sandbox_stream_event, setting_value, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, GpuResourceRequirements, + HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, + ListSandboxesResponse, PlatformEvent, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, Sandbox, SandboxCondition, SandboxLogLine, SandboxPhase, + SandboxResponse, SandboxStatus, SandboxStreamEvent, ServiceStatus, SettingValue, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, + setting_value, }; use std::collections::HashMap; use std::fs; @@ -766,6 +767,10 @@ fn test_tls(server: &TestServer) -> TlsOptions { server.tls.with_gateway_name("openshell") } +fn gpu_requirements(count: Option) -> GpuResourceRequirements { + GpuResourceRequirements { count } +} + #[tokio::test] async fn 
sandbox_create_keeps_command_sessions_by_default() { let server = run_server().await; @@ -782,7 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { "openshell", &[], true, - false, + None, None, None, None, @@ -825,7 +830,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { "openshell", &[], true, - false, + None, Some("500m"), Some("2Gi"), None, @@ -902,7 +907,7 @@ async fn sandbox_create_sends_driver_config_json() { "openshell", &[], true, - false, + None, None, None, Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#), @@ -959,6 +964,98 @@ async fn sandbox_create_sends_driver_config_json() { ); } +#[tokio::test] +async fn sandbox_create_sends_gpu_default_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-default"), + None, + "openshell", + &[], + true, + Some(gpu_requirements(None)), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU requirement should be sent"); + + assert_eq!(gpu.count, None); +} + +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + 
&server.endpoint, + Some("gpu-two"), + None, + "openshell", + &[], + true, + Some(gpu_requirements(Some(2))), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU requirement should be sent"); + + assert_eq!(gpu.count, Some(2)); +} + #[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -976,7 +1073,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { "openshell", &[], true, - false, + None, None, None, None, @@ -1034,7 +1131,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { "openshell", &[], true, - false, + None, None, None, None, @@ -1088,7 +1185,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { "openshell", &[], true, - false, + None, None, None, None, @@ -1134,7 +1231,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { "openshell", &[], true, - false, + None, None, None, None, @@ -1176,7 +1273,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { "openshell", &[], false, - false, + None, None, None, None, @@ -1222,7 +1319,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { "openshell", &[], false, - false, + None, None, None, None, @@ -1268,7 +1365,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { "openshell", &[], true, - false, + None, None, None, None, @@ -1314,7 +1411,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { "openshell", &[], false, - false, + None, None, None, None, @@ -1356,7 +1453,7 @@ async fn 
sandbox_create_sends_environment_variables() { "openshell", &[], true, - false, + None, None, None, None, diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 9718b50f2..fb4927aac 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -1,24 +1,107 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Shared GPU request helpers. +//! Shared GPU resource requirement helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::ResourceRequirements as SandboxResourceRequirements; +use crate::proto::compute::v1::{ + GpuResourceRequirements as DriverGpuResourceRequirements, + ResourceRequirements as DriverResourceRequirements, +}; -/// Resolve a GPU request into CDI device identifiers. +/// Return whether sandbox resource requirements request a GPU. +#[must_use] +pub fn public_gpu_requested(resources: Option<&SandboxResourceRequirements>) -> bool { + resources + .and_then(|resources| resources.gpu.as_ref()) + .is_some() +} + +/// Return the requested sandbox GPU count, if one was specified. +#[must_use] +pub fn public_gpu_count(resources: Option<&SandboxResourceRequirements>) -> Option { + resources + .and_then(|resources| resources.gpu.as_ref()) + .and_then(|gpu| gpu.count) +} + +/// Return whether compute-driver resource requirements request a GPU. +#[must_use] +pub fn driver_gpu_requested(resources: Option<&DriverResourceRequirements>) -> bool { + driver_gpu_requirements(resources).is_some() +} + +/// Return the requested compute-driver GPU count, if one was specified. +#[must_use] +pub fn driver_gpu_count(resources: Option<&DriverResourceRequirements>) -> Option { + driver_gpu_requirements(resources).and_then(|gpu| gpu.count) +} + +/// Return the requested compute-driver GPU requirements, if present. 
+#[must_use] +pub fn driver_gpu_requirements( + resources: Option<&DriverResourceRequirements>, +) -> Option<&DriverGpuResourceRequirements> { + resources.and_then(|resources| resources.gpu.as_ref()) +} + +/// Resolve a compute-driver GPU request into CDI device identifiers. /// /// `None` means no GPU was requested. A GPU request with no explicit CDI /// devices uses the CDI all-GPU request; otherwise the driver-configured CDI /// devices pass through unchanged. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, cdi_devices: &[String]) -> Option> { - gpu.then(|| { - if cdi_devices.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - cdi_devices.to_vec() +pub fn cdi_gpu_device_ids( + gpu: Option<&DriverGpuResourceRequirements>, + cdi_devices: &[String], +) -> Option> { + match gpu { + Some(_) if cdi_devices.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(_) => Some(cdi_devices.to_vec()), + None => None, + } +} + +/// Validate a compute-driver GPU request against driver-owned specific devices. +/// +/// Drivers call this when a sandbox request combines portable GPU requirements +/// with exact device identifiers in `driver_config`. +/// +/// # Errors +/// Returns an error when the sandbox GPU request is absent or when `gpu.count` +/// does not equal the number of specific devices. A single exact device is +/// compatible with the default sandbox GPU request where `gpu.count` is absent. 
+pub fn validate_specific_gpu_device_request( + gpu: Option<&DriverGpuResourceRequirements>, + specific_devices: &[String], + driver_config_field: &str, +) -> Result<(), String> { + let device_count = specific_devices.len(); + if device_count == 0 { + return Ok(()); + } + + let Some(gpu) = gpu else { + return Err(format!("{driver_config_field} requires a gpu request")); + }; + + let Some(count) = gpu.count else { + if device_count == 1 { + return Ok(()); } - }) + return Err(format!( + "{driver_config_field} requires an explicit gpu count matching its length ({device_count})" + )); + }; + + if usize::try_from(count).ok() != Some(device_count) { + return Err(format!( + "gpu count ({count}) must match {driver_config_field} length ({device_count})" + )); + } + + Ok(()) } #[cfg(test)] @@ -27,22 +110,26 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, &[]), None); + assert_eq!(cdi_gpu_device_ids(None, &[]), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let gpu = DriverGpuResourceRequirements { count: None }; + assert_eq!( - cdi_gpu_device_ids(true, &[]), + cdi_gpu_device_ids(Some(&gpu), &[]), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] fn cdi_gpu_device_ids_passes_explicit_device_ids_through() { + let gpu = DriverGpuResourceRequirements { count: None }; + assert_eq!( cdi_gpu_device_ids( - true, + Some(&gpu), &[ "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string() @@ -54,4 +141,92 @@ mod tests { ]) ); } + + #[test] + fn validate_specific_gpu_device_request_ignores_empty_devices() { + validate_specific_gpu_device_request(None, &[], "driver_config.cdi_devices") + .expect("empty exact device lists should not be validated"); + } + + #[test] + fn validate_specific_gpu_device_request_accepts_matching_count() { + let gpu = DriverGpuResourceRequirements { count: Some(2) }; + let specific_devices = vec![ + "nvidia.com/gpu=0".to_string(), + 
"nvidia.com/gpu=1".to_string(), + ]; + + validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect("matching count should be accepted"); + } + + #[test] + fn validate_specific_gpu_device_request_accepts_missing_count_for_one_device() { + let gpu = DriverGpuResourceRequirements { count: None }; + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect("single exact device should be compatible with a default GPU request"); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_missing_gpu_request() { + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + let err = validate_specific_gpu_device_request( + None, + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("missing GPU request should be rejected"); + + assert_eq!(err, "driver_config.cdi_devices requires a gpu request"); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_missing_count_for_multiple_devices() { + let gpu = DriverGpuResourceRequirements { count: None }; + let specific_devices = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + let err = validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("missing count should be rejected for multiple devices"); + + assert_eq!( + err, + "driver_config.cdi_devices requires an explicit gpu count matching its length (2)" + ); + } + + #[test] + fn validate_specific_gpu_device_request_rejects_mismatch() { + let gpu = DriverGpuResourceRequirements { count: Some(2) }; + let specific_devices = vec!["nvidia.com/gpu=0".to_string()]; + + let err = validate_specific_gpu_device_request( + Some(&gpu), + &specific_devices, + "driver_config.cdi_devices", + ) + .expect_err("mismatched count should be rejected"); + + assert_eq!( + err, + "gpu 
count (2) must match driver_config.cdi_devices length (1)" + ); + } } diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index 7f74cbe17..d16f53456 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-only GPU requests are rejected; exact CDI device lists with more than one entry require an explicit GPU count matching the device list length. | The agent child process does not retain these supervisor privileges. 
diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 05137a2b0..f77205e6b 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -27,7 +27,9 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, supervisor_image_should_refresh, }; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{ + cdi_gpu_device_ids, driver_gpu_requirements, validate_specific_gpu_device_request, +}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -36,11 +38,11 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + GetSandboxResponse, GpuResourceRequirements, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto_struct::{ deserialize_optional_non_empty_string_list, 
struct_to_json_value, @@ -461,7 +463,8 @@ impl DockerComputeDriver { let driver_config = DockerSandboxDriverConfig::from_template(template).map_err(Status::invalid_argument)?; - Self::validate_gpu_request(spec.gpu, config.supports_gpu, &driver_config)?; + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + Self::validate_gpu_request(gpu_requirements, config.supports_gpu, &driver_config)?; Ok(()) } @@ -509,21 +512,29 @@ impl DockerComputeDriver { } fn validate_gpu_request( - gpu: bool, + gpu_requirements: Option<&GpuResourceRequirements>, supports_gpu: bool, driver_config: &DockerSandboxDriverConfig, ) -> Result<(), Status> { - if !gpu && driver_config.cdi_devices.is_some() { - return Err(Status::invalid_argument( - "driver_config.cdi_devices requires gpu=true", + if gpu_requirements.is_some() && !supports_gpu { + return Err(Status::failed_precondition( + "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); } - if gpu && !supports_gpu { - return Err(Status::failed_precondition( - "docker GPU sandboxes require Docker CDI support. 
Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", + if let Some(cdi_devices) = driver_config.cdi_devices.as_deref() { + validate_specific_gpu_device_request( + gpu_requirements, + cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(Status::invalid_argument)?; + } else if gpu_requirements.and_then(|gpu| gpu.count).is_some() { + return Err(Status::invalid_argument( + "docker GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices", )); } + Ok(()) } @@ -2121,14 +2132,16 @@ fn build_device_requests(sandbox: &DriverSandbox) -> Result DriverSandbox { environment: HashMap::from([("TEMPLATE_ENV".to_string(), "template".to_string())]), ..Default::default() }), - gpu: false, + resource_requirements: None, sandbox_token: String::new(), }), status: None, @@ -79,6 +80,12 @@ fn list_string_driver_config(field: &str, values: &[&str]) -> prost_types::Struc } } +fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } +} + fn runtime_config() -> DockerDriverRuntimeConfig { DockerDriverRuntimeConfig { default_image: "image:latest".to_string(), @@ -1007,7 +1014,21 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(None)); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!(err.message().contains("Docker CDI")); +} + +#[test] +fn validate_sandbox_rejects_missing_gpu_support_before_request_shape() { + let config = runtime_config(); + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = 
Some(gpu_resources(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -1020,7 +1041,7 @@ fn validate_sandbox_rejects_invalid_cdi_devices_before_gpu_capability() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -1035,7 +1056,7 @@ fn validate_sandbox_rejects_unknown_driver_config_fields() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])); @@ -1045,12 +1066,116 @@ fn validate_sandbox_rejects_unknown_driver_config_fields() { assert!(err.message().contains("unknown field")); } +#[test] +fn validate_sandbox_rejects_gpu_count_request() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(Some(2))); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("GPU count requests are not supported") + ); +} + +#[test] +fn validate_sandbox_accepts_gpu_count_matching_cdi_devices() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resources(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[ 
+ "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); + + DockerComputeDriver::validate_sandbox(&sandbox, &config) + .expect("matching explicit CDI device count should be accepted"); +} + +#[test] +fn validate_sandbox_accepts_single_cdi_device_without_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resources(None)); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); + + DockerComputeDriver::validate_sandbox(&sandbox, &config) + .expect("single exact CDI device should be compatible with a default GPU request"); +} + +#[test] +fn validate_sandbox_rejects_multiple_cdi_devices_without_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resources(None)); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("requires an explicit gpu count matching its length (2)") + ); +} + +#[test] +fn validate_sandbox_rejects_cdi_devices_without_gpu_request() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox + .spec + .as_mut() + .unwrap() + .template + .as_mut() + .unwrap() + .driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("requires a gpu request")); +} + +#[test] +fn validate_sandbox_rejects_gpu_count_mismatched_cdi_devices() { + let 
mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resources(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); +} + #[test] fn validate_sandbox_rejects_template_errors_before_device_config() { let config = runtime_config(); let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); let template = spec.template.as_mut().unwrap(); template.agent_socket_path = "/tmp/agent.sock".to_string(); template.driver_config = Some(cdi_devices_config(&[])); @@ -1088,7 +1213,7 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resources(None)); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -1111,7 +1236,7 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { config.supports_gpu = true; let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); let create_body = build_container_create_body(&sandbox, &config).unwrap(); @@ -1130,7 +1255,25 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { } #[test] -fn 
build_container_create_body_rejects_cdi_devices_without_gpu() { +fn build_container_create_body_rejects_gpu_count_mismatched_cdi_devices() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resources(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); + + let err = build_container_create_body(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); +} + +#[test] +fn build_container_create_body_rejects_cdi_devices_without_gpu_request() { let mut sandbox = test_sandbox(); sandbox .spec @@ -1143,14 +1286,14 @@ fn build_container_create_body_rejects_cdi_devices_without_gpu() { let err = build_container_create_body(&sandbox, &runtime_config()).unwrap_err(); assert_eq!(err.code(), tonic::Code::InvalidArgument); - assert!(err.message().contains("requires gpu=true")); + assert!(err.message().contains("requires a gpu request")); } #[test] fn build_container_create_body_rejects_empty_cdi_devices() { let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; + spec.resource_requirements = Some(gpu_resources(None)); spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); let err = build_container_create_body(&sandbox, &runtime_config()).unwrap_err(); diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 6ad0b27c8..3cdb9fa57 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -62,9 +62,9 @@ the supervisor's network namespace mount setup on AppArmor-enabled nodes. 
## GPU Support When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +for `nvidia.com/gpu` and requests the configured GPU count in the workload spec. +When no count is set, the driver requests one GPU resource. The sandbox image +must provide the user-space libraries needed by the agent workload. ## Driver Config POC @@ -97,5 +97,6 @@ POC parser renders the keys listed above and rejects unknown fields. `pod.runtime_class_name` maps to PodSpec `runtimeClassName` and overrides the driver's configured `default_runtime_class_name`; the typed public `SandboxTemplate.runtime_class_name` still takes precedence when set. Use the -public `gpu` flag for the default GPU request and `driver_config` only for -additional driver-owned resource details. +public `--gpu` flag for the default GPU request, pass a count to `--gpu` for +counted GPU requests, and use `driver_config` only for additional driver-owned +resource details. 
diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index dc636efc3..b0c288d0d 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -17,6 +17,7 @@ use kube::{Client, Error as KubeError}; use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, SUPERVISOR_IMAGE_BINARY_PATH, }; +use openshell_core::gpu::driver_gpu_requirements; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -25,8 +26,9 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, - WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, + GetCapabilitiesResponse, GpuResourceRequirements, WatchSandboxesDeletedEvent, + WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, + watch_sandboxes_event, }; use openshell_core::proto_struct::{struct_to_json_object, value_to_json}; use serde::Deserialize; @@ -281,8 +283,12 @@ impl KubernetesComputeDriver { pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(tonic::Status::invalid_argument)?; - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - if gpu_requested + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + validate_gpu_request(gpu_requirements)?; + if gpu_requirements.is_some() && 
!self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? @@ -376,6 +382,13 @@ impl KubernetesComputeDriver { pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) .map_err(KubernetesDriverError::InvalidArgument)?; + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + validate_gpu_request(gpu_requirements).map_err(|status| { + KubernetesDriverError::InvalidArgument(status.message().to_string()) + })?; let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -638,6 +651,18 @@ impl KubernetesComputeDriver { } } +fn validate_gpu_request( + gpu_requirements: Option<&GpuResourceRequirements>, +) -> Result<(), tonic::Status> { + if gpu_requirements.and_then(|gpu| gpu.count) == Some(0) { + return Err(tonic::Status::invalid_argument( + "gpu count must be greater than 0", + )); + } + + Ok(()) +} + fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { let mut labels = BTreeMap::new(); labels.insert(LABEL_SANDBOX_ID.to_string(), sandbox.id.clone()); @@ -1201,7 +1226,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s_with_gpu_requirements( + template, + driver_gpu_requirements(spec.resource_requirements.as_ref()), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1231,9 +1262,9 @@ fn sandbox_to_k8s_spec( let pod_env = spec_pod_env(spec); root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s( + sandbox_template_to_k8s_with_gpu_requirements( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + driver_gpu_requirements(spec.and_then(|s| s.resource_requirements.as_ref())), 
&pod_env, inject_workspace, params, @@ -1246,12 +1277,30 @@ fn sandbox_to_k8s_spec( ) } +#[cfg(test)] fn sandbox_template_to_k8s( template: &SandboxTemplate, gpu: bool, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, +) -> serde_json::Value { + let gpu_requirements = gpu.then_some(GpuResourceRequirements { count: None }); + sandbox_template_to_k8s_with_gpu_requirements( + template, + gpu_requirements.as_ref(), + spec_environment, + inject_workspace, + params, + ) +} + +fn sandbox_template_to_k8s_with_gpu_requirements( + template: &SandboxTemplate, + gpu_requirements: Option<&GpuResourceRequirements>, + spec_environment: &std::collections::HashMap, + inject_workspace: bool, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let driver_config = kubernetes_driver_config(template); @@ -1331,7 +1380,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu_requirements.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1440,7 +1489,7 @@ fn sandbox_template_to_k8s( serde_json::Value::Array(volume_mounts), ); - if let Some(resources) = container_resources(template, gpu) { + if let Some(resources) = container_resources(template, gpu_requirements) { container.insert("resources".to_string(), resources); } apply_agent_driver_resources(&mut container, &driver_config.containers.agent.resources); @@ -1618,7 +1667,10 @@ fn app_armor_profile_to_k8s(profile: &AppArmorProfile) -> serde_json::Value { value } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu_requirements: Option<&GpuResourceRequirements>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay 
the typed DriverResourceRequirements on top. @@ -1651,8 +1703,12 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option> = @@ -2000,10 +2054,9 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, template: Some(SandboxTemplate { driver_config: Some(json_struct(serde_json::json!({ - "cdi_devices": ["nvidia.com/gpu=0"] + "gpu_device_ids": ["0000:2d:00.0"] }))), ..Default::default() }), @@ -2014,6 +2067,28 @@ mod tests { let err = KubernetesSandboxDriverConfig::from_sandbox(&sandbox).unwrap_err(); assert!(err.contains("unknown field")); + assert!(err.contains("gpu_device_ids")); + } + + #[test] + fn validate_rejects_zero_gpu_count() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + resource_requirements: Some(ResourceRequirements { + gpu: Some(GpuResourceRequirements { count: Some(0) }), + }), + ..SandboxSpec::default() + }), + ..Sandbox::default() + }; + + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| driver_gpu_requirements(spec.resource_requirements.as_ref())); + let err = validate_gpu_request(gpu_requirements).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("gpu count must be greater than 0")); } #[test] @@ -2345,6 +2420,26 @@ mod tests { ); } + #[test] + fn gpu_count_sandbox_adds_requested_gpu_limit() { + let pod_template = { + let params = SandboxPodParams::default(); + let gpu_requirements = GpuResourceRequirements { count: Some(2) }; + sandbox_template_to_k8s_with_gpu_requirements( + &SandboxTemplate::default(), + Some(&gpu_requirements), + &std::collections::HashMap::new(), + true, + &params, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") + ); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { diff --git a/crates/openshell-driver-podman/README.md
b/crates/openshell-driver-podman/README.md index 68a223bde..c1400d918 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | `driver_config.cdi_devices` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | `driver_config.cdi_devices` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-only GPU requests are rejected; exact CDI device lists with more than one entry require an explicit GPU count matching the device list length. | The restricted agent child does not retain these supervisor privileges. 
diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index afcc17585..00fb375e1 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -5,7 +5,9 @@ use crate::config::PodmanComputeConfig; use openshell_core::ComputeDriverError; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{ + cdi_gpu_device_ids, driver_gpu_requirements, validate_specific_gpu_device_request, +}; use openshell_core::proto::compute::v1::{DriverSandbox, DriverSandboxTemplate}; use openshell_core::proto_struct::deserialize_optional_non_empty_string_list; use openshell_core::{driver_mounts, proto_struct}; @@ -484,14 +486,16 @@ fn build_devices(sandbox: &DriverSandbox) -> Result>, Co let cdi_devices = PodmanSandboxDriverConfig::from_sandbox(sandbox)? .cdi_devices .unwrap_or_default(); - if !spec.gpu && !cdi_devices.is_empty() { - return Err(ComputeDriverError::InvalidArgument( - "driver_config.cdi_devices requires gpu=true".to_string(), - )); - } + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + validate_specific_gpu_device_request( + gpu_requirements, + &cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(ComputeDriverError::InvalidArgument)?; Ok( - cdi_gpu_device_ids(spec.gpu, &cdi_devices).map(|device_ids| { + cdi_gpu_device_ids(gpu_requirements, &cdi_devices).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -1092,6 +1096,7 @@ fn parse_memory_to_bytes(quantity: &str) -> Option { #[cfg(test)] mod tests { use super::*; + use openshell_core::proto::compute::v1::{GpuResourceRequirements, ResourceRequirements}; static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); @@ -1133,6 +1138,12 @@ mod tests { } } + fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } + } + 
#[test] fn parse_cpu_millicore() { assert_eq!(parse_cpu_to_microseconds("500m"), Some(50_000)); @@ -1246,7 +1257,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }); let config = test_config(); @@ -1264,7 +1275,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_devices_config(&["nvidia.com/gpu=0"])), ..Default::default() @@ -1281,7 +1292,60 @@ mod tests { } #[test] - fn container_spec_rejects_cdi_devices_without_gpu() { + fn container_spec_accepts_gpu_count_matching_cdi_devices() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!(spec["devices"].as_array().map(Vec::len), Some(2)); + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some("nvidia.com/gpu=0") + ); + assert_eq!( + spec["devices"][1]["path"].as_str(), + Some("nvidia.com/gpu=1") + ); + } + + #[test] + fn container_spec_rejects_gpu_count_mismatched_cdi_devices() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(DriverSandboxTemplate { + driver_config: 
Some(cdi_devices_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + + let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); + } + + #[test] + fn container_spec_rejects_cdi_devices_without_gpu_request() { use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; let mut sandbox = test_sandbox("test-id", "test-name"); @@ -1296,7 +1360,7 @@ mod tests { let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); - assert!(err.to_string().contains("requires gpu=true")); + assert!(err.to_string().contains("requires a gpu request")); } #[test] @@ -1305,7 +1369,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_devices_config(&[])), ..Default::default() @@ -1325,7 +1389,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(DriverSandboxTemplate { driver_config: Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])), ..Default::default() diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 6f9762c15..17864c941 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -11,7 +11,10 @@ use crate::watcher::{ }; use openshell_core::ComputeDriverError; use openshell_core::driver_utils::supervisor_image_should_refresh; -use 
openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::gpu::{driver_gpu_requirements, validate_specific_gpu_device_request}; +use openshell_core::proto::compute::v1::{ + DriverSandbox, GetCapabilitiesResponse, GpuResourceRequirements, +}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -281,24 +284,54 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); + let gpu_requirements = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| driver_gpu_requirements(Some(requirements))); let driver_config = PodmanSandboxDriverConfig::from_sandbox(sandbox)?; - if !gpu_requested && driver_config.cdi_devices.is_some() { - return Err(ComputeDriverError::InvalidArgument( - "driver_config.cdi_devices requires gpu=true".to_string(), - )); - } - Self::validate_gpu_request(gpu_requested)?; + let cdi_devices = driver_config.cdi_devices.as_deref(); + Self::validate_gpu_request(gpu_requirements, cdi_devices)?; self.validate_user_volume_mounts_available(sandbox).await?; Ok(()) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + fn validate_gpu_request( + gpu_requirements: Option<&GpuResourceRequirements>, + cdi_devices: Option<&[String]>, + ) -> Result<(), ComputeDriverError> { + Self::validate_gpu_request_with_capacity( + gpu_requirements, + cdi_devices, + Self::has_gpu_capacity(), + ) + } + + fn validate_gpu_request_with_capacity( + gpu_requirements: Option<&GpuResourceRequirements>, + cdi_devices: Option<&[String]>, + has_gpu_capacity: bool, + ) -> Result<(), ComputeDriverError> { + if gpu_requirements.is_some() && !has_gpu_capacity { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); 
} + + if let Some(cdi_devices) = cdi_devices { + validate_specific_gpu_device_request( + gpu_requirements, + cdi_devices, + "driver_config.cdi_devices", + ) + .map_err(ComputeDriverError::InvalidArgument)?; + } else if gpu_requirements.and_then(|gpu| gpu.count).is_some() { + return Err(ComputeDriverError::InvalidArgument( + "podman GPU count requests are not supported; use --gpu without a count or driver_config.cdi_devices" + .to_string(), + )); + } + Ok(()) } @@ -693,6 +726,100 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_gpu_count() { + let gpu = GpuResourceRequirements { count: Some(2) }; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity(Some(&gpu), None, true) + .expect_err("gpu count should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("GPU count requests are not supported") + ); + } + + #[test] + fn validate_gpu_request_accepts_single_cdi_device_without_gpu_count() { + let gpu = GpuResourceRequirements { count: None }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + + PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect("single exact CDI device should pass count validation"); + } + + #[test] + fn validate_gpu_request_rejects_missing_gpu_capacity_before_request_shape() { + let gpu = GpuResourceRequirements { count: Some(2) }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + false, + ) + .expect_err("missing GPU capacity should be rejected before request shape"); + + assert!(matches!(err, ComputeDriverError::Precondition(_))); + assert!(err.to_string().contains("no NVIDIA GPU devices")); + } + + #[test] + fn validate_gpu_request_rejects_multiple_cdi_devices_without_gpu_count() { + let gpu = 
GpuResourceRequirements { count: None }; + let cdi_devices = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect_err("missing CDI device count should be rejected for multiple devices"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("requires an explicit gpu count matching its length (2)") + ); + } + + #[test] + fn validate_gpu_request_rejects_cdi_devices_without_gpu_request() { + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + None, + Some(&cdi_devices), + false, + ) + .expect_err("missing GPU request should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!(err.to_string().contains("requires a gpu request")); + } + + #[test] + fn validate_gpu_request_rejects_mismatched_cdi_device_count() { + let gpu = GpuResourceRequirements { count: Some(2) }; + let cdi_devices = vec!["nvidia.com/gpu=0".to_string()]; + let err = PodmanComputeDriver::validate_gpu_request_with_capacity( + Some(&gpu), + Some(&cdi_devices), + true, + ) + .expect_err("mismatched CDI device count should be rejected"); + + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!( + err.to_string() + .contains("gpu count (2) must match driver_config.cdi_devices length (1)") + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. 
diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 30fecd8be..16a9742a5 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -29,6 +29,7 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::gpu::{driver_gpu_requirements, validate_specific_gpu_device_request}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -627,7 +628,12 @@ impl VmDriver { overlay_preparation: OverlayPreparation, ) -> Result<(), Status> { self.ensure_provisioning_active(&sandbox.id).await?; - let is_gpu = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + let is_gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| driver_gpu_requirements(Some(requirements))) + .is_some(); self.publish_platform_event( sandbox.id.clone(), platform_event( @@ -3079,7 +3085,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu if let Some(template) = spec.template.as_ref() { validate_vm_sandbox_template(template)?; } - validate_vm_gpu_request(sandbox, gpu_enabled)?; + validate_gpu_request(sandbox, gpu_enabled)?; Ok(()) } @@ -3100,18 +3106,33 @@ fn validate_vm_sandbox_template(template: &SandboxTemplate) -> Result<(), Status } #[allow(clippy::result_large_err)] -fn validate_vm_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { +fn validate_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { let spec = sandbox .spec .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - let _ = vm_gpu_device_id(sandbox)?; - if spec.gpu && !gpu_enabled { + let gpu_requirements = 
driver_gpu_requirements(spec.resource_requirements.as_ref()); + let gpu_count = gpu_requirements.and_then(|gpu| gpu.count); + + if gpu_requirements.is_some() && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); } + + if gpu_count == Some(0) { + return Err(Status::invalid_argument("gpu count must be greater than 0")); + } + + let _ = vm_gpu_device_id(sandbox)?; + + if gpu_count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "VM GPU sandboxes support only one GPU", + )); + } + Ok(()) } @@ -3124,19 +3145,21 @@ fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { .map_err(Status::invalid_argument)? .gpu_device_ids .unwrap_or_default(); - if !spec.gpu && !gpu_device_ids.is_empty() { - return Err(Status::invalid_argument( - "driver_config.gpu_device_ids requires gpu=true", - )); - } + let gpu_requirements = driver_gpu_requirements(spec.resource_requirements.as_ref()); + validate_specific_gpu_device_request( + gpu_requirements, + &gpu_device_ids, + "driver_config.gpu_device_ids", + ) + .map_err(Status::invalid_argument)?; if gpu_device_ids.len() > 1 { return Err(Status::invalid_argument( "vm driver currently supports at most one gpu_device_ids entry", )); } - Ok(spec - .gpu + Ok(gpu_requirements + .is_some() .then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) } @@ -5064,6 +5087,7 @@ mod tests { }; use openshell_core::proto::compute::v1::{ DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, + GpuResourceRequirements, ResourceRequirements, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -5102,6 +5126,12 @@ mod tests { } } + fn gpu_resources(count: Option) -> ResourceRequirements { + ResourceRequirements { + gpu: Some(GpuResourceRequirements { count }), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -5169,7 +5199,7 @@ mod tests { let sandbox = Sandbox 
{ id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }), ..Default::default() @@ -5180,12 +5210,32 @@ mod tests { assert!(err.message().contains("GPU support is not enabled")); } + #[test] + fn validate_vm_sandbox_rejects_missing_gpu_support_before_request_shape() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, false) + .expect_err("missing GPU support should be rejected before request shape"); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("GPU support is not enabled")); + } + #[test] fn validate_vm_sandbox_accepts_gpu_when_enabled() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), ..Default::default() }), ..Default::default() @@ -5194,11 +5244,123 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_accepts_gpu_count_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(1))), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("one GPU should be accepted when enabled"); + } + + #[test] + fn validate_vm_sandbox_accepts_single_gpu_device_without_gpu_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(None)), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + 
..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true) + .expect("single exact GPU device should be compatible with a default GPU request"); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids_without_gpu_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(None)), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU device IDs without count should be rejected"); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("requires an explicit gpu count matching its length (2)") + ); + } + + #[test] + fn validate_vm_sandbox_accepts_gpu_count_matching_device_id() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(1))), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true) + .expect("matching explicit GPU device count should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_above_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU VM request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("support only one GPU")); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_mismatched_device_id() { + let sandbox = 
Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resources(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("mismatched explicit GPU device count should be rejected"); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("gpu count (2) must match driver_config.gpu_device_ids length (1)") + ); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_device_without_gpu_request() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), ..Default::default() @@ -5208,9 +5370,9 @@ mod tests { ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device_ids without gpu should be rejected"); + .expect_err("gpu_device_ids without a GPU request should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device_ids requires gpu=true")); + assert!(err.message().contains("requires a gpu request")); } #[test] @@ -5218,7 +5380,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(Some(2))), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), ..Default::default() @@ -5238,7 +5400,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { driver_config: Some(gpu_device_ids_config(&[])), ..Default::default() @@ -5258,7 +5420,7 @@ mod tests { let sandbox = Sandbox { id: 
"sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { driver_config: Some(gpu_device_id_typo_config(&["0000:2d:00.0"])), ..Default::default() @@ -5278,7 +5440,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resources(None)), template: Some(SandboxTemplate { agent_socket_path: "/tmp/agent.sock".to_string(), driver_config: Some(gpu_device_ids_config(&[])), diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 812b9c59a..5c3e53860 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -19,10 +19,11 @@ use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, - ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, - compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, + GpuResourceRequirements as DriverGpuResourceRequirements, ListSandboxesRequest, + ResourceRequirements as DriverSandboxResourceRequirements, ValidateSandboxCreateRequest, + WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -1279,7 +1280,14 @@ fn driver_sandbox_spec_from_public( .as_ref() .map(|template| driver_sandbox_template_from_public(template, driver_kind)) .transpose()?, - gpu: 
spec.gpu, + resource_requirements: spec.resource_requirements.as_ref().map(|requirements| { + DriverSandboxResourceRequirements { + gpu: requirements + .gpu + .as_ref() + .map(|gpu| DriverGpuResourceRequirements { count: gpu.count }), + } + }), sandbox_token: String::new(), }) } @@ -1660,7 +1668,9 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec + .and_then(|sandbox_spec| sandbox_spec.resource_requirements.as_ref()) + .is_some_and(|requirements| openshell_core::gpu::public_gpu_requested(Some(requirements))); if !gpu_requested { return; } @@ -1856,6 +1866,26 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_requirement() { + let public = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(2) }), + }), + ..Default::default() + }; + + let driver = + driver_sandbox_spec_from_public(&public, None).expect("driver spec should map"); + + let gpu = driver + .resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("driver GPU requirement should be set"); + assert_eq!(gpu.count, Some(2)); + } + #[test] fn select_driver_config_forwards_only_matching_driver_block() { let config = prost_types::Struct { @@ -2355,7 +2385,9 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }), ); @@ -2383,13 +2415,7 @@ mod tests { ..Default::default() }); - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: false, - ..Default::default() - }), - ); + 
rewrite_user_facing_conditions(&mut status, Some(&SandboxSpec::default())); assert_eq!(status.unwrap().conditions[0].message, original); } @@ -2668,7 +2694,9 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2691,7 +2719,9 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!(stored.spec.as_ref().is_some_and(|spec| { + openshell_core::gpu::public_gpu_requested(spec.resource_requirements.as_ref()) + })); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index e60ce3995..2817e7381 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -98,9 +98,11 @@ fn emit_sandbox_create_telemetry( } else { SandboxTemplateSource::Default }; + let gpu_requested = + openshell_core::gpu::public_gpu_requested(spec.resource_requirements.as_ref()); openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + gpu_requested, spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 03a69d6e9..16b05a90b 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -134,6 +134,9 @@ pub(super) fn validate_sandbox_spec( validate_env_entries(&tmpl.environment, "spec.template.environment")?; } + // --- spec.resource_requirements.gpu --- + validate_gpu_request_fields(spec)?; + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -147,6 +150,14 @@ 
pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_request_fields(spec: &openshell_core::proto::SandboxSpec) -> Result<(), Status> { + if openshell_core::gpu::public_gpu_count(spec.resource_requirements.as_ref()) == Some(0) { + return Err(Status::invalid_argument("gpu count must be greater than 0")); + } + + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -760,12 +771,38 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: None }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(2) }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(openshell_core::proto::ResourceRequirements { + gpu: Some(openshell_core::proto::GpuResourceRequirements { count: Some(0) }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu count must be greater than 0")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index caf8e9961..c75ec3438 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ 
b/docs/reference/sandbox-compute-drivers.mdx @@ -53,14 +53,19 @@ openshell sandbox create \ ``` Driver config is for fields without a stable public flag. Prefer `--cpu`, -`--memory`, and `--gpu` for portable resource intent. +`--memory`, and `--gpu` for supported resource intent. Pass a count to `--gpu` +when the active driver supports counted allocation. Docker and Podman reject +count-only GPU selection. If `driver_config` lists more than one exact CDI +device, pass `--gpu COUNT`; the count must match the number of listed devices. +A single exact CDI device is compatible with the default `--gpu` request. Exact GPU device selection remains driver-owned and requires `--gpu`. Docker and Podman accept `cdi_devices`; replace the top-level `docker` key with `podman` when using the Podman driver, for example `{"docker":{"cdi_devices":["nvidia.com/gpu=0"]}}`. The VM driver accepts `gpu_device_ids`, for example `{"vm":{"gpu_device_ids":["0000:2d:00.0"]}}`; -the current VM implementation accepts at most one entry. +the current VM implementation accepts at most one entry and allows either +`--gpu` or `--gpu 1` when `gpu_device_ids` is set. For Kubernetes, `pod.runtime_class_name` maps to PodSpec `runtimeClassName`. It overrides the gateway's configured default runtime class for that sandbox, diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 1a54d0a06..c531e9d15 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -70,6 +70,20 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs by passing a count to `--gpu`: + +```shell +openshell sandbox create --gpu 2 -- claude +``` + +Support for counted GPU requests is driver-dependent. Kubernetes honors a +counted `--gpu` request by setting the `nvidia.com/gpu` limit. Docker and Podman +reject count-only selection. 
If `driver_config` lists more than one exact CDI +device, pass `--gpu COUNT`; the count must match the number of listed devices. +A single exact CDI device is compatible with the default `--gpu` request. VM +gateways accept only one GPU, either through `--gpu` or `--gpu 1`; a single +`gpu_device_ids` entry works with either form. + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index dbcb9e818..679433a6f 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -83,8 +83,9 @@ message DriverSandboxSpec { map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; + // Portable resource requirements used by the gateway for driver selection + // and by drivers for provisioning. + ResourceRequirements resource_requirements = 9; reserved 10; reserved "gpu_device"; // Gateway-minted JWT identifying this sandbox to the gateway. Set by @@ -96,6 +97,18 @@ message DriverSandboxSpec { string sandbox_token = 11; } +message ResourceRequirements { + // GPU requirements for the sandbox. Presence indicates a GPU request. + GpuResourceRequirements gpu = 1; +} + +// Driver GPU resource requirements. +message GpuResourceRequirements { + // Optional number of GPUs requested. When omitted, the driver uses its + // default GPU assignment behavior. + optional uint32 count = 1; +} + // Driver-owned runtime template consumed by the compute platform. // // This message describes the sandbox workload in backend-neutral terms. 
diff --git a/proto/openshell.proto b/proto/openshell.proto index d701956d3..fc8975b02 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -317,8 +317,9 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; + // Portable resource requirements used by the gateway for driver selection + // and by drivers for provisioning. + ResourceRequirements resource_requirements = 9; reserved 10; reserved "gpu_device"; // Field 11 was `proposal_approval_mode`. The approval mode is now a @@ -329,6 +330,18 @@ message SandboxSpec { reserved "proposal_approval_mode"; } +message ResourceRequirements { + // GPU requirements for the sandbox. Presence indicates a GPU request. + GpuResourceRequirements gpu = 1; +} + +// Public GPU resource requirements. +message GpuResourceRequirements { + // Optional number of GPUs requested. When omitted, the driver uses its + // default GPU assignment behavior. + optional uint32 count = 1; +} + // Public sandbox template mapped onto compute-driver template inputs. message SandboxTemplate { // Fully-qualified OCI image reference used to boot the sandbox. diff --git a/rfc/0004-sandbox-resource-requirements/README.md b/rfc/0004-sandbox-resource-requirements/README.md index 01b4319dd..e18c97bb7 100644 --- a/rfc/0004-sandbox-resource-requirements/README.md +++ b/rfc/0004-sandbox-resource-requirements/README.md @@ -69,8 +69,10 @@ tracked separately in issue #1492. - Defining the general driver-specific configuration passthrough API. Issue #1492 tracks that related API surface. - Publishing allocated resource identities in sandbox status. -- Preserving long-term compatibility for `gpu`, `gpu_device`, or a - GPU-specific `gpu_count` request field. +- Preserving alpha-era compatibility for `gpu`, `gpu_device`, or a + GPU-specific `gpu_count` request field. 
The legacy GPU-specific request + fields are intentionally not carried forward into the API shape this RFC + aims to stabilize. ## Proposal @@ -89,13 +91,24 @@ message SandboxSpec { // Portable resource requirements used by the gateway for driver selection // and by drivers for provisioning. - SandboxResourceRequirements resource_requirements = 11; + SandboxResourceRequirements resource_requirements = 9; - reserved 9, 10; - reserved "gpu", "gpu_device"; + reserved 10; + reserved "gpu_device"; } ``` +The public sandbox API is still alpha. This migration intentionally replaces +the old `bool gpu = 9` field with the typed `resource_requirements = 9` message +instead of reserving the legacy field number. Old live requests and persisted +sandbox records that encode GPU intent through the legacy boolean are not +migrated; callers should use a matching OpenShell CLI/API version and recreate +GPU sandboxes after upgrade when they need the new typed shape. Field 9 also +changes wire type (varint to length-delimited), so a protobuf-encoded legacy +`gpu = true` value is not read as a GPU request and may fail to decode. Avoiding +alpha-era reserved fields keeps the proto surface closer to the API intended +for stabilization. + `SandboxTemplate.resources` keeps its existing role as platform-native workload configuration. It may contain Kubernetes-style CPU, memory, and extended resource requests and limits, but it is not the portable resource contract. @@ -551,10 +564,10 @@ message DriverSandboxSpec { string log_level = 1; map environment = 5; DriverSandboxTemplate template = 6; - DriverSandboxResourceRequirements resource_requirements = 11; + DriverSandboxResourceRequirements resource_requirements = 9; - reserved 9, 10; - reserved "gpu", "gpu_device"; + reserved 10; + reserved "gpu_device"; } ``` @@ -562,6 +575,12 @@ Driver-owned resource requirement messages should have the same semantics as the public messages, but live in `compute_driver.proto` to keep the public and internal contracts separated. 
+The compute-driver API is version-coupled to the gateway in current deployments: +local drivers are launched by the gateway at startup, and the driver proto is +not treated as a public compatibility surface. The driver proto therefore makes +the same alpha-era replacement as the public API (`resource_requirements` takes +over field 9) rather than preserving transitional GPU fields. + ### Driver capabilities Replace GPU-specific capability fields with coarse resource capability