diff --git a/internal/gcs-sidecar/handlers.go b/internal/gcs-sidecar/handlers.go index f5a7c48d5e..37b2cff8e3 100644 --- a/internal/gcs-sidecar/handlers.go +++ b/internal/gcs-sidecar/handlers.go @@ -57,10 +57,9 @@ func (b *Bridge) createContainer(req *request) (err error) { return errors.Wrap(err, "failed to unmarshal createContainer") } - // containerConfig can be of type uvnConfig or hcsschema.HostedSystem or guestresource.CWCOWHostedSystem + // containerConfig can be of type uvmConfig or guestresource.CWCOWHostedSystem var ( uvmConfig prot.UvmConfig - hostedSystemConfig hcsschema.HostedSystem cwcowHostedSystemConfig guestresource.CWCOWHostedSystem ) if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &uvmConfig); err == nil && @@ -68,11 +67,6 @@ func (b *Bridge) createContainer(req *request) (err error) { systemType := uvmConfig.SystemType timeZoneInformation := uvmConfig.TimeZoneInformation log.G(ctx).Tracef("createContainer: uvmConfig: {systemType: %v, timeZoneInformation: %v}}", systemType, timeZoneInformation) - } else if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &hostedSystemConfig); err == nil && - hostedSystemConfig.SchemaVersion != nil && hostedSystemConfig.Container != nil { - schemaVersion := hostedSystemConfig.SchemaVersion - container := hostedSystemConfig.Container - log.G(ctx).Tracef("rpcCreate: HostedSystemConfig: {schemaVersion: %v, container: %v}}", schemaVersion, container) } else if err = commonutils.UnmarshalJSONWithHresult(containerConfig, &cwcowHostedSystemConfig); err == nil && cwcowHostedSystemConfig.Spec.Version != "" && cwcowHostedSystemConfig.CWCOWHostedSystem.Container != nil { cwcowHostedSystem := cwcowHostedSystemConfig.CWCOWHostedSystem @@ -536,7 +530,6 @@ func (b *Bridge) modifyServiceSettings(req *request) (err error) { defer span.End() defer func() { oc.SetSpanStatus(span, err) }() - // Todo: Add policy enforcement for modifying service settings modifyRequest, err := unmarshalModifyServiceSettings(req) if err != nil { return fmt.Errorf("failed to unmarshal modifyServiceSettings request: %w", err) @@ -551,28 +544,110 @@ func (b *Bridge) modifyServiceSettings(req *request) (err error) { switch settings.RPCType { case guestrequest.RPCModifyServiceSettings, guestrequest.RPCStartLogForwarding, guestrequest.RPCStopLogForwarding: log.G(req.ctx).Tracef("%v request received for LogForwardService, proceeding with policy enforcement for log sources", settings.RPCType) - // Enforce the policy for log sources in the request and update the settings with allowed log sources. - // For cwcow, the sidecar-GCS will verify the allowed log sources against policy and append the necessary GUIDs to the ones allowed. Rest are dropped. - // The Enforcer will have to unmarshal the log sources, enforce the policy and then marshal it back to a Base64 encoded JSON string which is what inbox GCS expects. - // It can query etw.GetDefaultLogSources to get the default log sources if the policy allows, and allow providers matching the default list during policy enforcement. - // This is because the log sources can be a combination of default and user specified log sources for which GUIDs need to be appended based on the policy enforcement. if settings.Settings != "" { - // - // allowedLogSources, err := b.hostState.securityOptions.PolicyEnforcer.EnforceLogForwardServiceSettingsPolicy(req.ctx, settings.LogSources) + // Decode the base64-encoded log sources config so we can + // enforce policy on the requested provider list. + logSources, err := etw.DecodeAndUnmarshalLogSources(settings.Settings) + if err != nil { + return fmt.Errorf("failed to decode log sources: %w", err) + } + + // Validate host-supplied (Name, GUID) pairs before + // name-based policy enforcement. + if err := validateLogProviders(logSources.LogConfig.Sources); err != nil { + return fmt.Errorf("log providers rejected: %w", err) + } - // For now, we are skipping the policy enforcement and allowing all log sources as the policy enforcer implementation is in progress. We will add the enforcement back once it's implemented. - allowedLogSources := settings.Settings // This is Base64 encoded JSON string of log sources - log.G(req.ctx).Tracef("Allowed log sources after policy enforcement: %v", allowedLogSources) + // Collect every requested provider name and ask the + // enforcer to validate them as a batch. The enforcer's + // behaviour depends on allow_log_provider_dropping in the + // active policy: + // - false (default, fail-close): any disallowed provider + // causes the call to be denied. + // - true: disallowed providers are silently dropped and + // the kept subset is returned for forwarding. + var requestedNames []string + for _, source := range logSources.LogConfig.Sources { + for _, provider := range source.Providers { + requestedNames = append(requestedNames, provider.ProviderName) + } + } - // Update the allowed log sources in the settings. This will be forwarded to inbox GCS which expects the log sources in a JSON string format with GUIDs for providers included. - allowedLogSources, err := etw.UpdateLogSources(allowedLogSources, false, true) + keptNames, err := b.hostState.securityOptions.PolicyEnforcer.EnforceLogProviderPolicy( + req.ctx, requestedNames) + if err != nil { + return fmt.Errorf("log providers denied by policy: %w", err) + } + + // Build a quick lookup for the kept set so we can trim the + // LogSourcesInfo to only those providers the policy allowed. + keepSet := make(map[string]struct{}, len(keptNames)) + for _, name := range keptNames { + keepSet[name] = struct{}{} + } + + // Detect trimming by scanning requested names against + // keepSet. We cannot use len(kept) != len(requested): + // the rego enforcer returns providers_to_keep via a set + // (see getProvidersToKeep → keepSet.toArray()), so a + // duplicate-name request like [A, A, B] returns [A, B] + // even when nothing was dropped, which would otherwise + // trip a false-positive warning and a needless re-marshal. + dropped := make([]string, 0) + seenDropped := make(map[string]struct{}) + for _, name := range requestedNames { + if _, ok := keepSet[name]; ok { + continue + } + if _, dup := seenDropped[name]; dup { + continue + } + seenDropped[name] = struct{}{} + dropped = append(dropped, name) + } + + // Trim happens in-place on the parsed structure so we can + // hand it to UpdateLogSourcesFromInfo without a redundant + // base64-decode + JSON-unmarshal round-trip (we already + // decoded above for enforcement). + trimmed := logSources + if len(dropped) > 0 { + // Surface the drop so operators have a breadcrumb — + // under allow_log_provider_dropping the pod boots + // silently, and forwardlogs may itself be off, so + // without this warning the trim is invisible. + log.G(req.ctx).WithFields(map[string]interface{}{ + "requested": requestedNames, + "kept": keptNames, + "dropped": dropped, + }).Warn("log providers trimmed by policy (allow_log_provider_dropping)") + + // Trim each source's provider list to only the + // allowed names. Empty sources are preserved to keep + // the shape stable; inbox GCS handles them as no-ops. + for i := range trimmed.LogConfig.Sources { + src := &trimmed.LogConfig.Sources[i] + filtered := make([]etw.EtwProvider, 0, len(src.Providers)) + for _, p := range src.Providers { + if _, ok := keepSet[p.ProviderName]; ok { + filtered = append(filtered, p) + } + } + src.Providers = filtered + } + } + + // Apply GUID resolution (and any other inbox-GCS prep) + // against the policy-trimmed payload and hand off to + // inbox GCS. + allowedLogSources, err := etw.UpdateLogSourcesFromInfo(trimmed, false, true) if err != nil { return fmt.Errorf("failed to update log sources: %w", err) } settings.Settings = allowedLogSources } default: - log.G(req.ctx).Warningf("modifyServiceSettings for LogForwardService with RPCType: %v, skipping policy enforcement", settings.RPCType) + return fmt.Errorf("modifyServiceSettings for LogForwardService: unsupported RPCType %q", settings.RPCType) } modifyRequest.Settings = settings buf, err := json.Marshal(modifyRequest) @@ -589,12 +664,59 @@ func (b *Bridge) modifyServiceSettings(req *request) (err error) { log.G(req.ctx).Warningf("modifyServiceSettings for LogForwardService with empty settings, skipping policy enforcement") } default: - log.G(req.ctx).Warningf("modifyServiceSettings with PropertyType: %v, skipping policy enforcement", modifyRequest.PropertyType) + return fmt.Errorf("modifyServiceSettings: unsupported PropertyType %q", modifyRequest.PropertyType) } b.forwardRequestToGcs(req) return nil } +// validateLogProviders validates host-supplied log providers before they +// reach the name-based policy enforcer. +// +// CWCOW policy approves provider names, but inbox GCS subscribes by GUID. If +// the host could send {Name: "allowed", GUID: ""} the name-based +// enforcer would approve and the disallowed GUID would still be forwarded +// (resolveGUIDsWithLookup keeps any GUID the host set). To close that bypass +// the sidecar rejects, before enforcement, any entry whose (Name, GUID) pair +// is not verifiable against the well-known ETW map: +// +// - Name == "": rejected. Policy is name-based; a GUID-only entry has +// nothing for the enforcer to evaluate. +// - Name + GUID where Name is not in the well-known map: rejected. We have +// no ground truth to compare the GUID against, so we cannot verify the +// host's claim. Name-only is still accepted for downstream resolution to +// stay best-effort. +// - Name + GUID where the GUID disagrees with the well-known lookup for +// Name: rejected. +// +// Name-only entries are passed through unchanged; the sidecar fills in the +// canonical GUID after enforcement via etw.UpdateLogSourcesFromInfo. +func validateLogProviders(sources []etw.Source) error { + for _, src := range sources { + for _, p := range src.Providers { + if p.ProviderName == "" { + return fmt.Errorf("provider with no name is not allowed (GUID %q)", p.ProviderGUID) + } + if p.ProviderGUID == "" { + continue + } + well := etw.GetProviderGUIDFromName(p.ProviderName) + if well == "" { + return fmt.Errorf("provider %q: name not in well-known ETW map; cannot verify supplied GUID %q", p.ProviderName, p.ProviderGUID) + } + suppliedTrimmed := strings.TrimSuffix(strings.TrimPrefix(strings.TrimSpace(p.ProviderGUID), "{"), "}") + supplied, err := guid.FromString(suppliedTrimmed) + if err != nil { + return fmt.Errorf("provider %q: invalid GUID %q: %w", p.ProviderName, p.ProviderGUID, err) + } + if !strings.EqualFold(supplied.String(), well) { + return fmt.Errorf("provider %q: supplied GUID %q does not match well-known GUID %q", p.ProviderName, p.ProviderGUID, well) + } + } + } + return nil +} + func volumeGUIDFromLayerPath(path string) (string, bool) { if p, ok := strings.CutPrefix(path, `\\?\Volume{`); ok { if q, ok := strings.CutSuffix(p, `}\Files`); ok { diff --git a/internal/gcs-sidecar/handlers_test.go b/internal/gcs-sidecar/handlers_test.go index 6de3a0a605..923a72845a 100644 --- a/internal/gcs-sidecar/handlers_test.go +++ b/internal/gcs-sidecar/handlers_test.go @@ -5,6 +5,7 @@ package bridge import ( "context" + "encoding/base64" "encoding/json" "io" "testing" @@ -14,7 +15,9 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs/prot" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/vm/vmutils/etw" "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/sirupsen/logrus" ) // buildModifySettingsRequest creates a serialized ModifySettings request message @@ -327,3 +330,498 @@ func TestModifySettings_PolicyFragment_TypeAssertionFailure(t *testing.T) { t.Fatal("expected error for empty fragment, got nil") } } + +// buildLogForwardServiceRequest builds a serialized ServiceModificationRequest +// for the LogForwardService with the given provider names baked into a +// base64-encoded LogSourcesInfo payload. +func buildLogForwardServiceRequest(t *testing.T, providerNames ...string) []byte { + t.Helper() + + providers := make([]etw.EtwProvider, 0, len(providerNames)) + for _, name := range providerNames { + providers = append(providers, etw.EtwProvider{ProviderName: name}) + } + info := etw.LogSourcesInfo{ + LogConfig: etw.LogConfig{ + Sources: []etw.Source{{ + Type: "etw", + Providers: providers, + }}, + }, + } + infoBytes, err := json.Marshal(info) + if err != nil { + t.Fatalf("failed to marshal log sources: %v", err) + } + encoded := base64.StdEncoding.EncodeToString(infoBytes) + + inner := &guestrequest.LogForwardServiceRPCRequest{ + RPCType: guestrequest.RPCModifyServiceSettings, + Settings: encoded, + } + req := prot.ServiceModificationRequest{ + RequestBase: prot.RequestBase{ + ContainerID: UVMContainerID, + ActivityID: guid.GUID{}, + }, + PropertyType: string(prot.LogForwardService), + Settings: inner, + } + b, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + return b +} + +// newModifyServiceSettingsRequest wraps the given LogForwardService payload +// in a bridge `request` ready for modifyServiceSettings. +func newModifyServiceSettingsRequest(payload []byte) *request { + return &request{ + ctx: context.Background(), + header: messageHeader{ + Type: prot.MsgTypeRequest | prot.MsgType(prot.RPCModifyServiceSettings), + Size: uint32(len(payload)) + prot.HdrSize, + ID: 1, + }, + activityID: guid.GUID{}, + message: payload, + } +} + +// TestModifyServiceSettings_LogForward_PolicyAllow_ForwardsToGCS verifies that +// when every requested provider is allowed by policy, the call succeeds and +// the (possibly GUID-resolved) request is forwarded to inbox GCS. +func TestModifyServiceSettings_LogForward_PolicyAllow_ForwardsToGCS(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + // Use a provider that is in the known etw_map so UpdateLogSources's GUID + // resolution succeeds. + payload := buildLogForwardServiceRequest(t, "microsoft.windows.hyperv.compute") + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err != nil { + t.Fatalf("modifyServiceSettings with allowed provider returned error: %v", err) + } + + select { + case <-b.sendToGCSCh: + // Forwarded to GCS as expected. + case <-time.After(time.Second): + t.Fatal("timed out waiting for request to be forwarded to GCS") + } +} + +// TestModifyServiceSettings_LogForward_PolicyDeny_ReturnsErrorAndDoesNotForward +// verifies that when any requested provider is denied by policy, the call +// fails and the request is not forwarded to inbox GCS. +func TestModifyServiceSettings_LogForward_PolicyDeny_ReturnsErrorAndDoesNotForward(t *testing.T) { + b := newTestBridge(&securitypolicy.ClosedDoorSecurityPolicyEnforcer{}) + + payload := buildLogForwardServiceRequest(t, "microsoft.windows.hyperv.compute") + req := newModifyServiceSettingsRequest(payload) + + err := b.modifyServiceSettings(req) + if err == nil { + t.Fatal("expected modifyServiceSettings to fail under ClosedDoor enforcer") + } + + // The request must NOT have been forwarded to GCS. + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("denied request must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// droppingLogProviderEnforcer is a test stub that approves only the configured +// allow-list of provider names; any others are silently dropped from the +// returned subset. It mirrors the regoEnforcer's behaviour under +// allow_log_provider_dropping := true and never returns an error. +type droppingLogProviderEnforcer struct { + securitypolicy.OpenDoorSecurityPolicyEnforcer + allowed map[string]struct{} +} + +func (e *droppingLogProviderEnforcer) EnforceLogProviderPolicy(_ context.Context, providerNames []string) ([]string, error) { + kept := make([]string, 0, len(providerNames)) + for _, name := range providerNames { + if _, ok := e.allowed[name]; ok { + kept = append(kept, name) + } + } + return kept, nil +} + +// TestModifyServiceSettings_LogForward_PolicyDropping_TrimsForwardedPayload +// verifies the silent-drop path in the sidecar: when the enforcer returns a +// strict subset of the requested providers, the call succeeds and the payload +// forwarded to inbox GCS contains only the kept providers (not the original +// disallowed ones). +func TestModifyServiceSettings_LogForward_PolicyDropping_TrimsForwardedPayload(t *testing.T) { + kept := "microsoft.windows.hyperv.compute" + dropped := "some-bogus-provider" + enforcer := &droppingLogProviderEnforcer{ + allowed: map[string]struct{}{kept: {}}, + } + b := newTestBridge(enforcer) + + payload := buildLogForwardServiceRequest(t, kept, dropped) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err != nil { + t.Fatalf("modifyServiceSettings under dropping enforcer returned error: %v", err) + } + + var forwarded request + select { + case forwarded = <-b.sendToGCSCh: + case <-time.After(time.Second): + t.Fatal("timed out waiting for request to be forwarded to GCS") + } + + // Decode the forwarded request back into LogSourcesInfo and confirm the + // disallowed provider has been stripped while the allowed one survives. + var fwdReq prot.ServiceModificationRequest + fwdReq.Settings = &guestrequest.LogForwardServiceRPCRequest{} + if err := json.Unmarshal(forwarded.message, &fwdReq); err != nil { + t.Fatalf("failed to unmarshal forwarded request: %v", err) + } + innerSettings, ok := fwdReq.Settings.(*guestrequest.LogForwardServiceRPCRequest) + if !ok { + t.Fatalf("forwarded settings has unexpected type: %T", fwdReq.Settings) + } + logSources, err := etw.DecodeAndUnmarshalLogSources(innerSettings.Settings) + if err != nil { + t.Fatalf("failed to decode forwarded log sources: %v", err) + } + + var sawKept, sawDropped bool + for _, src := range logSources.LogConfig.Sources { + for _, p := range src.Providers { + if p.ProviderName == kept { + sawKept = true + } + if p.ProviderName == dropped { + sawDropped = true + } + } + } + if !sawKept { + t.Errorf("expected forwarded payload to contain kept provider %q", kept) + } + if sawDropped { + t.Errorf("expected dropped provider %q to be absent from forwarded payload", dropped) + } +} + +// captureHook is a tiny logrus hook that records every entry it sees. +// Used by TestModifyServiceSettings_LogForward_PolicyDropping_NoFalsePositive +// to assert the "log providers trimmed by policy" Warn is *not* emitted when +// the only reason kept and requested differ is set-deduplication. +type captureHook struct { + entries []*logrus.Entry +} + +func (h *captureHook) Levels() []logrus.Level { return logrus.AllLevels } +func (h *captureHook) Fire(e *logrus.Entry) error { + h.entries = append(h.entries, e) + return nil +} + +// TestModifyServiceSettings_LogForward_PolicyDropping_NoFalsePositive guards +// against a false-positive trim warning + needless re-marshal when the +// enforcer returns a deduplicated set. The rego implementation builds +// providers_to_keep via a stringSet (see getProvidersToKeep), so a request +// with duplicate provider names like [A, A] comes back as [A] even when +// nothing was actually dropped. Detection must be based on "some requested +// name is missing from keepSet", not len(kept) != len(requested). +func TestModifyServiceSettings_LogForward_PolicyDropping_NoFalsePositive(t *testing.T) { + name := "microsoft.windows.hyperv.compute" + enforcer := &droppingLogProviderEnforcer{ + allowed: map[string]struct{}{name: {}}, + } + b := newTestBridge(enforcer) + + // Two copies of the same allowed provider. dedup in the enforcer means + // kept=[name] while requested=[name, name]; the lengths differ but the + // set of requested names is fully covered, so this is NOT a trim. + payload := buildLogForwardServiceRequest(t, name, name) + req := newModifyServiceSettingsRequest(payload) + + hook := &captureHook{} + logrus.AddHook(hook) + defer func() { + // logrus has no public RemoveHook; reset all hooks to clear ours. + logrus.StandardLogger().ReplaceHooks(logrus.LevelHooks{}) + }() + + if err := b.modifyServiceSettings(req); err != nil { + t.Fatalf("modifyServiceSettings under dropping enforcer (dedup) returned error: %v", err) + } + + // Must forward to GCS. + select { + case <-b.sendToGCSCh: + case <-time.After(time.Second): + t.Fatal("timed out waiting for request to be forwarded to GCS") + } + + // Must NOT have emitted the trim warning: nothing was actually dropped. + for _, e := range hook.entries { + if e.Level == logrus.WarnLevel && + e.Message == "log providers trimmed by policy (allow_log_provider_dropping)" { + t.Errorf("false-positive trim warning emitted on a dedup-only mismatch (kept=%v requested=%v dropped=%v)", + e.Data["kept"], e.Data["requested"], e.Data["dropped"]) + } + } +} + +// TestModifyServiceSettings_UnsupportedPropertyType_Denied verifies that a +// ModifyServiceSettings request whose PropertyType is not one the sidecar +// structurally understands is rejected and not forwarded to inbox GCS. +// +// An empty PropertyType is used because unmarshalModifyServiceSettings only +// validates non-empty PropertyType values, so this is the path that actually +// reaches the handler's outer switch default. +func TestModifyServiceSettings_UnsupportedPropertyType_Denied(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + r := prot.ServiceModificationRequest{ + RequestBase: prot.RequestBase{ + ContainerID: UVMContainerID, + ActivityID: guid.GUID{}, + }, + // PropertyType deliberately empty to exercise the handler's + // outer-switch default branch. + PropertyType: "", + } + payload, err := json.Marshal(r) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err == nil { + t.Fatal("expected modifyServiceSettings to fail for unsupported PropertyType") + } + + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("request with unsupported PropertyType must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// TestModifyServiceSettings_LogForward_UnsupportedRPCType_Denied verifies +// that a LogForwardService request carrying an RPCType the sidecar does not +// recognise is rejected and not forwarded to inbox GCS. +func TestModifyServiceSettings_LogForward_UnsupportedRPCType_Denied(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + inner := &guestrequest.LogForwardServiceRPCRequest{ + RPCType: guestrequest.RPCType("UnsupportedRPCType"), + } + r := prot.ServiceModificationRequest{ + RequestBase: prot.RequestBase{ + ContainerID: UVMContainerID, + ActivityID: guid.GUID{}, + }, + PropertyType: string(prot.LogForwardService), + Settings: inner, + } + payload, err := json.Marshal(r) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err == nil { + t.Fatal("expected modifyServiceSettings to fail for unsupported RPCType") + } + + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("request with unsupported RPCType must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// buildLogForwardServiceRequestWithProviders is the variant of +// buildLogForwardServiceRequest that lets each test set ProviderName and +// ProviderGUID independently, so the validateLogProviders tests can +// exercise mismatched and GUID-only payloads. +func buildLogForwardServiceRequestWithProviders(t *testing.T, providers []etw.EtwProvider) []byte { + t.Helper() + + info := etw.LogSourcesInfo{ + LogConfig: etw.LogConfig{ + Sources: []etw.Source{{ + Type: "etw", + Providers: providers, + }}, + }, + } + infoBytes, err := json.Marshal(info) + if err != nil { + t.Fatalf("failed to marshal log sources: %v", err) + } + encoded := base64.StdEncoding.EncodeToString(infoBytes) + + inner := &guestrequest.LogForwardServiceRPCRequest{ + RPCType: guestrequest.RPCModifyServiceSettings, + Settings: encoded, + } + req := prot.ServiceModificationRequest{ + RequestBase: prot.RequestBase{ + ContainerID: UVMContainerID, + ActivityID: guid.GUID{}, + }, + PropertyType: string(prot.LogForwardService), + Settings: inner, + } + b, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + return b +} + +// TestModifyServiceSettings_LogForward_GUIDOnly_Denied verifies that a +// provider entry with ProviderName=="" (GUID-only) is rejected before +// reaching policy enforcement. CWCOW policy is name-based, so a GUID-only +// entry has nothing for the enforcer to evaluate; accepting it would let the +// host smuggle a disallowed GUID past name-based policy. +func TestModifyServiceSettings_LogForward_GUIDOnly_Denied(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + payload := buildLogForwardServiceRequestWithProviders(t, []etw.EtwProvider{ + {ProviderName: "", ProviderGUID: "80ce50de-d264-4581-950d-abadeee0d340"}, + }) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err == nil { + t.Fatal("expected modifyServiceSettings to reject GUID-only provider entry") + } + + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("rejected request must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// TestModifyServiceSettings_LogForward_NameGUIDMismatch_Denied verifies that +// a provider entry whose ProviderGUID disagrees with the well-known map +// lookup for ProviderName is rejected. Without this check a hostile host +// could pair an allowed Name with a disallowed GUID and bypass name-based +// policy because inbox GCS subscribes by GUID. +func TestModifyServiceSettings_LogForward_NameGUIDMismatch_Denied(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + // Name resolves to 80ce50de-d264-4581-950d-abadeee0d340 in the + // well-known map; deliberately supply an unrelated valid GUID. + payload := buildLogForwardServiceRequestWithProviders(t, []etw.EtwProvider{ + { + ProviderName: "microsoft.windows.hyperv.compute", + ProviderGUID: "11111111-2222-3333-4444-555555555555", + }, + }) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err == nil { + t.Fatal("expected modifyServiceSettings to reject Name/GUID mismatch") + } + + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("rejected request must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// TestModifyServiceSettings_LogForward_UnknownNameWithGUID_Denied verifies +// that a provider entry whose ProviderName is not in the well-known ETW map +// is rejected when paired with a ProviderGUID: the sidecar has no ground +// truth to verify the host's claim against. +func TestModifyServiceSettings_LogForward_UnknownNameWithGUID_Denied(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + payload := buildLogForwardServiceRequestWithProviders(t, []etw.EtwProvider{ + { + ProviderName: "unknown-provider", + ProviderGUID: "11111111-2222-3333-4444-555555555555", + }, + }) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err == nil { + t.Fatal("expected modifyServiceSettings to reject unknown Name + GUID") + } + + select { + case fwd := <-b.sendToGCSCh: + t.Fatalf("rejected request must not be forwarded to GCS: %+v", fwd) + default: + // Good. + } +} + +// TestModifyServiceSettings_LogForward_NameMatchingGUID_Allowed verifies the +// positive path of validateLogProviders: a provider entry where +// ProviderGUID matches the well-known lookup for ProviderName passes +// validation and is forwarded to inbox GCS. +func TestModifyServiceSettings_LogForward_NameMatchingGUID_Allowed(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + payload := buildLogForwardServiceRequestWithProviders(t, []etw.EtwProvider{ + { + ProviderName: "microsoft.windows.hyperv.compute", + ProviderGUID: "80ce50de-d264-4581-950d-abadeee0d340", + }, + }) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err != nil { + t.Fatalf("modifyServiceSettings with matching Name/GUID returned error: %v", err) + } + + select { + case <-b.sendToGCSCh: + // Forwarded to GCS as expected. + case <-time.After(time.Second): + t.Fatal("timed out waiting for request to be forwarded to GCS") + } +} + +// TestModifyServiceSettings_LogForward_BracedGUID_Allowed verifies that the +// validator accepts GUID strings wrapped in `{...}` braces (the common +// canonical form on Windows). The well-known map stores the un-braced form, +// so the comparison must be brace-insensitive. +func TestModifyServiceSettings_LogForward_BracedGUID_Allowed(t *testing.T) { + b := newTestBridge(&securitypolicy.OpenDoorSecurityPolicyEnforcer{}) + + payload := buildLogForwardServiceRequestWithProviders(t, []etw.EtwProvider{ + { + ProviderName: "microsoft.windows.hyperv.compute", + ProviderGUID: "{80ce50de-d264-4581-950d-abadeee0d340}", + }, + }) + req := newModifyServiceSettingsRequest(payload) + + if err := b.modifyServiceSettings(req); err != nil { + t.Fatalf("modifyServiceSettings with braced matching GUID returned error: %v", err) + } + + select { + case <-b.sendToGCSCh: + // Forwarded to GCS as expected. + case <-time.After(time.Second): + t.Fatal("timed out waiting for request to be forwarded to GCS") + } +} diff --git a/internal/tools/securitypolicy/main.go b/internal/tools/securitypolicy/main.go index d5ad89e28b..be1860959d 100644 --- a/internal/tools/securitypolicy/main.go +++ b/internal/tools/securitypolicy/main.go @@ -68,6 +68,7 @@ func main() { config.AllowEnvironmentVariableDropping, config.AllowUnencryptedScratch, config.AllowCapabilityDropping, + config.AllowLogProviderDropping, ) } if err != nil { diff --git a/internal/uvm/start.go b/internal/uvm/start.go index 300cd63b9d..89d7ae475b 100644 --- a/internal/uvm/start.go +++ b/internal/uvm/start.go @@ -292,11 +292,28 @@ func (uvm *UtilityVM) Start(ctx context.Context) (err error) { if uvm.OS() == "windows" && uvm.logForwardingEnabled { // If the UVM is Windows and log forwarding is enabled, set the log sources // and start the log forwarding service. + // + // For confidential (CWCOW) UVMs, a failure here may be a policy + // violation (e.g. log_provider denied by the rego). In that case we + // must fail-close: return the error so the deferred + // uvm.hcsSystem.Terminate above tears the UVM down rather than + // leaving it half-initialised with a known-violating log + // configuration. For non-confidential WCOW, log-forwarding is + // best-effort (transient ETW issues, missing providers, etc. must + // not abort UVM start), so preserve the original log-and-continue + // behaviour. + failClose := uvm.HasConfidentialPolicy() if err := uvm.SetLogSources(ctx); err != nil { e.WithError(err).Error("failed to set log sources") + if failClose { + return fmt.Errorf("failed to set log sources: %w", err) + } } if err := uvm.StartLogForwarding(ctx); err != nil { e.WithError(err).Error("failed to start log forwarding") + if failClose { + return fmt.Errorf("failed to start log forwarding: %w", err) + } } } diff --git a/internal/vm/vmutils/etw/provider_map.go b/internal/vm/vmutils/etw/provider_map.go index 5b35206602..9a37022c2a 100644 --- a/internal/vm/vmutils/etw/provider_map.go +++ b/internal/vm/vmutils/etw/provider_map.go @@ -36,7 +36,7 @@ func GetDefaultLogSources() LogSourcesInfo { } // GetProviderGUIDFromName returns the provider GUID for a given provider name. If the provider name is not found in the map, it returns an empty string. -func getProviderGUIDFromName(providerName string) string { +func GetProviderGUIDFromName(providerName string) string { if guid, ok := etwNameToGUIDMap[strings.ToLower(providerName)]; ok { return guid } @@ -92,8 +92,8 @@ func mergeLogSources(resultSources []Source, userSources []Source) []Source { return resultSources } -// decodeAndUnmarshalLogSources decodes a base64-encoded JSON string and unmarshals it into a LogSourcesInfo. -func decodeAndUnmarshalLogSources(base64EncodedJSONLogConfig string) (LogSourcesInfo, error) { +// DecodeAndUnmarshalLogSources decodes a base64-encoded JSON string and unmarshals it into a LogSourcesInfo. +func DecodeAndUnmarshalLogSources(base64EncodedJSONLogConfig string) (LogSourcesInfo, error) { jsonBytes, err := base64.StdEncoding.DecodeString(base64EncodedJSONLogConfig) if err != nil { return LogSourcesInfo{}, fmt.Errorf("error decoding base64 log config: %w", err) @@ -127,7 +127,7 @@ func resolveGUIDsWithLookup(sources []Source) ([]Source, error) { sources[i].Providers[j].ProviderGUID = strings.ToLower(guid.String()) } if provider.ProviderName != "" && provider.ProviderGUID == "" { - sources[i].Providers[j].ProviderGUID = getProviderGUIDFromName(provider.ProviderName) + sources[i].Providers[j].ProviderGUID = GetProviderGUIDFromName(provider.ProviderName) } } } @@ -147,7 +147,7 @@ func stripRedundantGUIDs(sources []Source) ([]Source, error) { if err != nil { return nil, fmt.Errorf("invalid GUID %q for provider %q: %w", provider.ProviderGUID, provider.ProviderName, err) } - if strings.EqualFold(guid.String(), getProviderGUIDFromName(provider.ProviderName)) { + if strings.EqualFold(guid.String(), GetProviderGUIDFromName(provider.ProviderName)) { sources[i].Providers[j].ProviderGUID = "" } else { // If the GUID doesn't match the well-known GUID for the provider name, @@ -188,19 +188,29 @@ func marshalAndEncodeLogSources(logCfg LogSourcesInfo) (string, error) { // configuration and returns the updated log sources as a base64 encoded JSON string. // If there is an error in the process, it returns the original user provided log sources string. func UpdateLogSources(base64EncodedJSONLogConfig string, useDefaultLogSources bool, includeGUIDs bool) (string, error) { - var resultLogCfg LogSourcesInfo - if useDefaultLogSources { - resultLogCfg = defaultLogSourcesInfo - } - + var userLogSources LogSourcesInfo if base64EncodedJSONLogConfig != "" { - userLogSources, err := decodeAndUnmarshalLogSources(base64EncodedJSONLogConfig) + var err error + userLogSources, err = DecodeAndUnmarshalLogSources(base64EncodedJSONLogConfig) if err != nil { return "", fmt.Errorf("failed to decode and unmarshal user log sources: %w", err) } - resultLogCfg.LogConfig.Sources = mergeLogSources(resultLogCfg.LogConfig.Sources, userLogSources.LogConfig.Sources) + } + return UpdateLogSourcesFromInfo(userLogSources, useDefaultLogSources, includeGUIDs) +} +// UpdateLogSourcesFromInfo is the parsed-input variant of UpdateLogSources for +// callers that already have a LogSourcesInfo in hand (e.g. the gcs-sidecar +// after it decoded the payload for policy enforcement) and want to avoid a +// second base64-decode + JSON-unmarshal round-trip. +// +// An empty userLogSources is equivalent to passing "" to UpdateLogSources. +func UpdateLogSourcesFromInfo(userLogSources LogSourcesInfo, useDefaultLogSources bool, includeGUIDs bool) (string, error) { + var resultLogCfg LogSourcesInfo + if useDefaultLogSources { + resultLogCfg = defaultLogSourcesInfo } + resultLogCfg.LogConfig.Sources = mergeLogSources(resultLogCfg.LogConfig.Sources, userLogSources.LogConfig.Sources) var err error resultLogCfg.LogConfig.Sources, err = applyGUIDPolicy(resultLogCfg.LogConfig.Sources, includeGUIDs) diff --git a/internal/vm/vmutils/etw/provider_map_test.go b/internal/vm/vmutils/etw/provider_map_test.go index 4aa62d861c..55ce1b4ce3 100644 --- a/internal/vm/vmutils/etw/provider_map_test.go +++ b/internal/vm/vmutils/etw/provider_map_test.go @@ -23,9 +23,9 @@ func TestGetProviderGUIDFromName(t *testing.T) { } for _, tt := range tests { - got := getProviderGUIDFromName(tt.name) + got := GetProviderGUIDFromName(tt.name) if got != tt.expected { - t.Errorf("getProviderGUIDFromName(%q) = %q, want %q", tt.name, got, tt.expected) + t.Errorf("GetProviderGUIDFromName(%q) = %q, want %q", tt.name, got, tt.expected) } } } @@ -126,11 +126,11 @@ func buildTestUserLogSources(t *testing.T) LogSourcesInfo { nameOnlyProvider := "Microsoft.Windows.HyperV.Compute" nameAndGUIDProvider := "Microsoft.Windows.Containers.Setup" - guid := getProviderGUIDFromName(nameAndGUIDProvider) + guid := GetProviderGUIDFromName(nameAndGUIDProvider) if guid == "" { t.Fatalf("missing GUID mapping for provider %q", nameAndGUIDProvider) } - if getProviderGUIDFromName(nameOnlyProvider) == "" { + if GetProviderGUIDFromName(nameOnlyProvider) == "" { t.Fatalf("missing GUID mapping for provider %q", nameOnlyProvider) } @@ -185,7 +185,7 @@ func applyExpectedGUIDBehavior(cfg *LogSourcesInfo, includeGUIDs bool) { } } if provider.ProviderName != "" && provider.ProviderGUID == "" { - cfg.LogConfig.Sources[i].Providers[j].ProviderGUID = getProviderGUIDFromName(provider.ProviderName) + cfg.LogConfig.Sources[i].Providers[j].ProviderGUID = GetProviderGUIDFromName(provider.ProviderName) } continue } @@ -195,7 +195,7 @@ func applyExpectedGUIDBehavior(cfg *LogSourcesInfo, includeGUIDs bool) { if err != nil { continue } - if strings.EqualFold(guid.String(), getProviderGUIDFromName(provider.ProviderName)) { + if strings.EqualFold(guid.String(), GetProviderGUIDFromName(provider.ProviderName)) { cfg.LogConfig.Sources[i].Providers[j].ProviderGUID = "" } else { cfg.LogConfig.Sources[i].Providers[j].ProviderGUID = strings.ToLower(guid.String()) diff --git a/pkg/securitypolicy/api.rego b/pkg/securitypolicy/api.rego index 88c3d64d14..9474c75bce 100644 --- a/pkg/securitypolicy/api.rego +++ b/pkg/securitypolicy/api.rego @@ -24,4 +24,5 @@ enforcement_points := { "load_fragment": {"introducedVersion": "0.9.0", "default_results": {"allowed": false, "add_module": false}, "use_framework": false}, "scratch_mount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, "scratch_unmount": {"introducedVersion": "0.10.0", "default_results": {"allowed": true}, "use_framework": false}, + "log_provider": {"introducedVersion": "0.11.0", "default_results": {"allowed": true, "providers_to_keep": null}, "use_framework": false}, } diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index 76e5b048a0..1f252b02d1 100644 --- a/pkg/securitypolicy/framework.rego +++ b/pkg/securitypolicy/framework.rego @@ -1333,6 +1333,47 @@ scratch_unmount := {"metadata": [remove_scratch_mount], "allowed": true} { } } +# Log provider validation for Windows containers. +# +# Two modes (mirrors allow_environment_variable_dropping): +# - allow_log_provider_dropping := false (default, fail-close): every +# requested provider name must appear in allowed_log_providers, otherwise +# the rule denies the entire request. +# - allow_log_provider_dropping := true: providers not in the allow-list are +# silently dropped from providers_to_keep; the call still allows and the +# caller is expected to only forward the remaining providers. +# +# Input: {"providers": [name, ...]} +# Output: {"allowed": bool, "providers_to_keep": [name, ...]} +default log_provider := {"allowed": false, "providers_to_keep": []} + +valid_log_providers := providers { + allow_log_provider_dropping + + providers := [name | + name := input.providers[_] + some allowed_provider in data.policy.allowed_log_providers + lower(name) == lower(allowed_provider) + ] +} + +valid_log_providers := providers { + not allow_log_provider_dropping + providers := input.providers +} + +log_providers_ok(providers) { + every name in providers { + some allowed_provider in data.policy.allowed_log_providers + lower(name) == lower(allowed_provider) + } +} + +log_provider := {"allowed": true, "providers_to_keep": providers} { + providers := valid_log_providers + log_providers_ok(providers) +} + # Registry changes validation default registry_changes := {"allowed": false} @@ -1861,6 +1902,11 @@ errors["no scratch at path to unmount"] { not scratch_mounted(input.unmountTarget) } +errors["log provider not allowed by policy"] { + input.rule == "log_provider" + not log_provider.allowed +} + errors[framework_version_error] { policy_framework_version == null framework_version_error := concat(" ", ["framework_version is missing. Current version:", version]) @@ -2341,6 +2387,7 @@ allow_dump_stacks := data.policy.allow_dump_stacks allow_runtime_logging := data.policy.allow_runtime_logging allow_environment_variable_dropping := data.policy.allow_environment_variable_dropping allow_unencrypted_scratch := data.policy.allow_unencrypted_scratch +allow_log_provider_dropping := data.policy.allow_log_provider_dropping # all flags not in the base set need to have default logic applied diff --git a/pkg/securitypolicy/open_door.rego b/pkg/securitypolicy/open_door.rego index 02da3fa9b6..fa277d9e7e 100644 --- a/pkg/securitypolicy/open_door.rego +++ b/pkg/securitypolicy/open_door.rego @@ -23,3 +23,4 @@ runtime_logging := {"allowed": true} load_fragment := {"allowed": true} scratch_mount := {"allowed": true} scratch_unmount := {"allowed": true} +log_provider := {"allowed": true} diff --git a/pkg/securitypolicy/opts.go b/pkg/securitypolicy/opts.go index a11685abc4..9446f9dd30 100644 --- a/pkg/securitypolicy/opts.go +++ b/pkg/securitypolicy/opts.go @@ -115,6 +115,13 @@ func WithAllowEnvVarDropping(allow bool) PolicyConfigOpt { } } +func WithAllowLogProviderDropping(allow bool) PolicyConfigOpt { + return func(config *PolicyConfig) error { + config.AllowLogProviderDropping = allow + return nil + } +} + func WithAllowCapabilityDropping(allow bool) PolicyConfigOpt { return func(config *PolicyConfig) error { config.AllowCapabilityDropping = allow diff --git a/pkg/securitypolicy/policy.rego b/pkg/securitypolicy/policy.rego index 195d462931..2702075305 100644 --- a/pkg/securitypolicy/policy.rego +++ b/pkg/securitypolicy/policy.rego @@ -26,4 +26,5 @@ runtime_logging := data.framework.runtime_logging load_fragment := data.framework.load_fragment scratch_mount := data.framework.scratch_mount scratch_unmount := data.framework.scratch_unmount +log_provider := data.framework.log_provider reason := data.framework.reason diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index 6afe7de6cc..965bdcf81f 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -2020,6 +2020,7 @@ func (constraints *generatedConstraints) toPolicy() *securityPolicyInternal { AllowEnvironmentVariableDropping: constraints.allowEnvironmentVariableDropping, AllowUnencryptedScratch: constraints.allowUnencryptedScratch, AllowCapabilityDropping: constraints.allowCapabilityDropping, + AllowLogProviderDropping: constraints.allowLogProviderDropping, } } @@ -2281,6 +2282,7 @@ func generateConstraints(r *rand.Rand, maxContainers int32) *generatedConstraint namespace: generateFragmentNamespace(testRand), svn: generateSVN(testRand), allowCapabilityDropping: false, + allowLogProviderDropping: false, ctx: context.Background(), } } @@ -2920,6 +2922,7 @@ type generatedConstraints struct { namespace string svn string allowCapabilityDropping bool + allowLogProviderDropping bool ctx context.Context } @@ -2935,6 +2938,7 @@ type generatedWindowsConstraints struct { namespace string svn string allowCapabilityDropping bool + allowLogProviderDropping bool ctx context.Context } @@ -2949,6 +2953,7 @@ func (constraints *generatedWindowsConstraints) toPolicy() *securityPolicyWindow AllowEnvironmentVariableDropping: constraints.allowEnvironmentVariableDropping, AllowUnencryptedScratch: constraints.allowUnencryptedScratch, AllowCapabilityDropping: constraints.allowCapabilityDropping, + AllowLogProviderDropping: constraints.allowLogProviderDropping, } } @@ -2993,6 +2998,7 @@ func generateWindowsConstraints(r *rand.Rand, maxContainers int32) *generatedWin allowEnvironmentVariableDropping: false, allowUnencryptedScratch: false, allowCapabilityDropping: false, + allowLogProviderDropping: false, namespace: generateFragmentNamespace(r), svn: generateSVN(r), ctx: context.Background(), diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 51da87a18c..7b9080c815 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -75,6 +75,7 @@ func Test_MarshalRego_Policy(t *testing.T) { p.allowEnvironmentVariableDropping, p.allowUnencryptedScratch, p.allowCapabilityDropping, + p.allowLogProviderDropping, ) if err != nil { t.Error(err) diff --git a/pkg/securitypolicy/regopolicy_windows_test.go b/pkg/securitypolicy/regopolicy_windows_test.go index 33b49a64f8..9f9a9bb0b6 100644 --- a/pkg/securitypolicy/regopolicy_windows_test.go +++ b/pkg/securitypolicy/regopolicy_windows_test.go @@ -1514,3 +1514,212 @@ func substituteUVMPath(sandboxID string, m mountInternal) mountInternal { _ = sandboxID return m } + +// Tests for log provider enforcement + +// newLogProviderTestPolicy builds a Rego policy whose allowed_log_providers +// list contains the given providers and returns the compiled enforcer. +// Pass no providers to get an empty allow-list. +// +// allow_log_provider_dropping is left unset so the test exercises the +// default fail-close mode. Use newLogProviderTestPolicyWithDropping to flip +// the mode. +func newLogProviderTestPolicy(t *testing.T, allowedProviders ...string) *regoEnforcer { + t.Helper() + return newLogProviderTestPolicyWithDropping(t, false, allowedProviders...) +} + +// newLogProviderTestPolicyWithDropping is the more general helper used by the +// mode-specific tests. It compiles a Rego policy that defines +// allowed_log_providers, sets allow_log_provider_dropping to dropping, and +// routes log_provider through the framework rule. +func newLogProviderTestPolicyWithDropping(t *testing.T, dropping bool, allowedProviders ...string) *regoEnforcer { + t.Helper() + var listLines string + for _, p := range allowedProviders { + listLines += fmt.Sprintf("\t\t%q,\n", p) + } + rego := fmt.Sprintf(`package policy + api_version := "%s" + framework_version := "%s" + + allow_log_provider_dropping := %t + + allowed_log_providers := [ +%s ] + + log_provider := data.framework.log_provider + `, apiVersion, frameworkVersion, dropping, listLines) + + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("failed to create policy: %v", err) + } + return policy +} + +func Test_Rego_EnforceLogProviderPolicy_Allowed_Windows(t *testing.T) { + policy := newLogProviderTestPolicy(t, + "microsoft.windows.hyperv.compute", + "microsoft-windows-guest-network-service", + ) + + kept, err := policy.EnforceLogProviderPolicy(context.Background(), + []string{"microsoft.windows.hyperv.compute"}) + if err != nil { + t.Errorf("expected allowed provider to pass: %v", err) + } + if len(kept) != 1 || kept[0] != "microsoft.windows.hyperv.compute" { + t.Errorf("expected kept=[microsoft.windows.hyperv.compute]; got %v", kept) + } +} + +func Test_Rego_EnforceLogProviderPolicy_Denied_Windows(t *testing.T) { + policy := newLogProviderTestPolicy(t, "microsoft.windows.hyperv.compute") + + _, err := policy.EnforceLogProviderPolicy(context.Background(), + []string{"some-malicious-provider"}) + if err == nil { + t.Errorf("expected unknown provider to be denied") + } +} + +func Test_Rego_EnforceLogProviderPolicy_CaseInsensitive_Windows(t *testing.T) { + policy := newLogProviderTestPolicy(t, "microsoft.windows.hyperv.compute") + + kept, err := policy.EnforceLogProviderPolicy(context.Background(), + []string{"Microsoft.Windows.Hyperv.Compute"}) + if err != nil { + t.Errorf("expected case-insensitive match to pass: %v", err) + } + // Rego preserves the input casing; we just confirm the name survived. + if len(kept) != 1 || kept[0] != "Microsoft.Windows.Hyperv.Compute" { + t.Errorf("expected kept=[Microsoft.Windows.Hyperv.Compute]; got %v", kept) + } +} + +func Test_Rego_EnforceLogProviderPolicy_OpenDoor_AllowsAll_Windows(t *testing.T) { + policy, err := newRegoPolicy(openDoorRego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("failed to create policy: %v", err) + } + + ctx := context.Background() + kept, err := policy.EnforceLogProviderPolicy(ctx, []string{"any-provider-at-all"}) + if err != nil { + t.Errorf("open door should allow any provider: %v", err) + } + if len(kept) != 1 || kept[0] != "any-provider-at-all" { + t.Errorf("open door should keep the requested provider; got %v", kept) + } +} + +func Test_Rego_EnforceLogProviderPolicy_EmptyAllowList_DeniesAll_Windows(t *testing.T) { + policy := newLogProviderTestPolicy(t) + + _, err := policy.EnforceLogProviderPolicy(context.Background(), + []string{"microsoft.windows.hyperv.compute"}) + if err == nil { + t.Errorf("expected empty allow list to deny all providers") + } +} + +// Test_Rego_EnforceLogProviderPolicy_PreFeatureAPIVersion_Allows_Windows pins +// the non-regression behaviour for policies authored before log_provider was +// introduced (api.rego entry: introducedVersion=0.11.0, default_results.allowed=true). +// Such policies omit allowed_log_providers entirely; EnforceLogProviderPolicy +// must return the input list unchanged with no error so existing CWCOW/WCOW +// policies do not break when the framework gains the new enforcement point. +func Test_Rego_EnforceLogProviderPolicy_PreFeatureAPIVersion_Allows_Windows(t *testing.T) { + // A pre-feature policy does not define log_provider at all (it was + // authored before the rule existed). The version-gated default_results + // path only fires when the policy has no rule for the enforcement + // point — including `log_provider := data.framework.log_provider` + // here would shadow the default and route through the framework rule, + // which defaults to deny. + rego := fmt.Sprintf(`package policy + api_version := "0.10.0" + framework_version := "%s" + `, frameworkVersion) + + policy, err := newRegoPolicy(rego, []oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Fatalf("failed to create policy: %v", err) + } + + ctx := context.Background() + kept, err := policy.EnforceLogProviderPolicy(ctx, + []string{"any-provider-not-in-any-list"}) + if err != nil { + t.Errorf("expected pre-0.11.0 policy to allow any provider via default_results: %v", err) + } + // default_results.providers_to_keep is null, so getProvidersToKeep + // returns the input list unchanged. + if len(kept) != 1 || kept[0] != "any-provider-not-in-any-list" { + t.Errorf("expected kept=[any-provider-not-in-any-list]; got %v", kept) + } +} + +// Test_Rego_EnforceLogProviderPolicy_EmptyProviderName_Denied_Windows pins +// the behaviour for the host-scrubbed-unknown-provider edge case: when the +// providerName is the empty string, no allow-list entry can match (allow-lists +// never contain ""), so enforcement must deny. +func Test_Rego_EnforceLogProviderPolicy_EmptyProviderName_Denied_Windows(t *testing.T) { + policy := newLogProviderTestPolicy(t, "microsoft.windows.hyperv.compute") + + _, err := policy.EnforceLogProviderPolicy(context.Background(), []string{""}) + if err == nil { + t.Errorf("expected empty providerName to be denied") + } +} + +// Test_Rego_EnforceLogProviderPolicy_Dropping_KeepsSubset_Windows exercises +// the silent-drop mode: when allow_log_provider_dropping := true the call +// allows even if some requested providers are not on the allow-list, and only +// the matching subset is returned. +func Test_Rego_EnforceLogProviderPolicy_Dropping_KeepsSubset_Windows(t *testing.T) { + policy := newLogProviderTestPolicyWithDropping(t, true, + "microsoft.windows.hyperv.compute", + "microsoft-windows-guest-network-service", + ) + + kept, err := policy.EnforceLogProviderPolicy(context.Background(), []string{ + "microsoft.windows.hyperv.compute", + "some-bogus-provider", + "microsoft-windows-guest-network-service", + }) + if err != nil { + t.Errorf("dropping mode should allow regardless of unknown providers: %v", err) + } + + keptSet := make(map[string]struct{}, len(kept)) + for _, n := range kept { + keptSet[n] = struct{}{} + } + if _, ok := keptSet["microsoft.windows.hyperv.compute"]; !ok { + t.Errorf("expected 'microsoft.windows.hyperv.compute' kept; got %v", kept) + } + if _, ok := keptSet["microsoft-windows-guest-network-service"]; !ok { + t.Errorf("expected 'microsoft-windows-guest-network-service' kept; got %v", kept) + } + if _, ok := keptSet["some-bogus-provider"]; ok { + t.Errorf("expected 'some-bogus-provider' to be dropped; got %v", kept) + } +} + +// Test_Rego_EnforceLogProviderPolicy_FailClose_AnyMissDenies_Windows confirms +// the inverse of the dropping test: with allow_log_provider_dropping := false +// (default) even one unknown provider in a batch fails the entire call. +func Test_Rego_EnforceLogProviderPolicy_FailClose_AnyMissDenies_Windows(t *testing.T) { + policy := newLogProviderTestPolicyWithDropping(t, false, + "microsoft.windows.hyperv.compute", + ) + + _, err := policy.EnforceLogProviderPolicy(context.Background(), []string{ + "microsoft.windows.hyperv.compute", + "some-bogus-provider", + }) + if err == nil { + t.Errorf("fail-close mode should deny when any provider is unknown") + } +} diff --git a/pkg/securitypolicy/securitypolicy.go b/pkg/securitypolicy/securitypolicy.go index f1d761d439..47ca1860bf 100644 --- a/pkg/securitypolicy/securitypolicy.go +++ b/pkg/securitypolicy/securitypolicy.go @@ -68,6 +68,12 @@ type PolicyConfig struct { // all containers within a pod to be run without scratch encryption. AllowUnencryptedScratch bool `json:"allow_unencrypted_scratch" toml:"allow_unencrypted_scratch"` AllowCapabilityDropping bool `json:"allow_capability_dropping" toml:"allow_capability_dropping"` + // AllowLogProviderDropping controls how EnforceLogProviderPolicy handles + // requested ETW providers that are not on the allow-list. When false + // (default, fail-close) any disallowed provider causes the entire + // modifyServiceSettings call to be denied. When true, disallowed providers + // are silently dropped and the call continues with the kept subset. + AllowLogProviderDropping bool `json:"allow_log_provider_dropping" toml:"allow_log_provider_dropping"` } func NewPolicyConfig(opts ...PolicyConfigOpt) (*PolicyConfig, error) { diff --git a/pkg/securitypolicy/securitypolicy_internal.go b/pkg/securitypolicy/securitypolicy_internal.go index c736fb58ed..42c10dbce0 100644 --- a/pkg/securitypolicy/securitypolicy_internal.go +++ b/pkg/securitypolicy/securitypolicy_internal.go @@ -19,6 +19,7 @@ type securityPolicyInternal struct { AllowEnvironmentVariableDropping bool AllowUnencryptedScratch bool AllowCapabilityDropping bool + AllowLogProviderDropping bool } // Internal version of Windows SecurityPolicy @@ -32,6 +33,7 @@ type securityPolicyWindowsInternal struct { AllowEnvironmentVariableDropping bool AllowUnencryptedScratch bool AllowCapabilityDropping bool + AllowLogProviderDropping bool } type securityPolicyFragment struct { @@ -97,6 +99,7 @@ func newSecurityPolicyInternal( allowDropEnvironmentVariables bool, allowUnencryptedScratch bool, allowDropCapabilities bool, + allowLogProviderDropping bool, ) (*securityPolicyInternal, error) { containersInternal, err := containersToInternal(containers) if err != nil { @@ -113,6 +116,7 @@ func newSecurityPolicyInternal( AllowEnvironmentVariableDropping: allowDropEnvironmentVariables, AllowUnencryptedScratch: allowUnencryptedScratch, AllowCapabilityDropping: allowDropCapabilities, + AllowLogProviderDropping: allowLogProviderDropping, }, nil } diff --git a/pkg/securitypolicy/securitypolicy_marshal.go b/pkg/securitypolicy/securitypolicy_marshal.go index 665dc9e4f0..9db7e3f19d 100644 --- a/pkg/securitypolicy/securitypolicy_marshal.go +++ b/pkg/securitypolicy/securitypolicy_marshal.go @@ -67,6 +67,7 @@ type OSAwareMarshalFunc func( allowEnvironmentVariableDropping bool, allowUnencryptedScratch bool, allowCapabilityDropping bool, + allowLogProviderDropping bool, ) (string, error) // osAwareMarshalRego handles both Linux and Windows containers @@ -83,6 +84,7 @@ func osAwareMarshalRego( allowEnvironmentVariableDropping bool, allowUnencryptedScratch bool, allowCapabilityDropping bool, + allowLogProviderDropping bool, ) (string, error) { if allowAll { if len(linuxContainers) > 0 || len(windowsContainers) > 0 { @@ -98,7 +100,8 @@ func osAwareMarshalRego( } return marshalRego(allowAll, linuxContainers, externalProcesses, fragments, allowPropertiesAccess, allowDumpStacks, allowRuntimeLogging, - allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapabilityDropping) + allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapabilityDropping, + allowLogProviderDropping) case "windows": if len(linuxContainers) > 0 { @@ -106,7 +109,8 @@ func osAwareMarshalRego( } return marshalWindowsRego(allowAll, windowsContainers, externalProcesses, fragments, allowPropertiesAccess, allowDumpStacks, allowRuntimeLogging, - allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapabilityDropping) + allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapabilityDropping, + allowLogProviderDropping) default: return "", fmt.Errorf("unsupported OS type: %s", osType) @@ -125,6 +129,7 @@ func marshalWindowsRego( allowEnvironmentVariableDropping bool, allowUnencryptedScratch bool, allowCapabilityDropping bool, + allowLogProviderDropping bool, ) (string, error) { if allowAll { if len(containers) > 0 { @@ -149,6 +154,7 @@ func marshalWindowsRego( AllowEnvironmentVariableDropping: allowEnvironmentVariableDropping, AllowUnencryptedScratch: allowUnencryptedScratch, AllowCapabilityDropping: allowCapabilityDropping, + AllowLogProviderDropping: allowLogProviderDropping, } return policy.marshalWindowsRego(), nil @@ -167,6 +173,7 @@ func marshalJSON( _ bool, _ bool, _ bool, + _ bool, ) (string, error) { var policy *SecurityPolicy if allowAll { @@ -198,6 +205,7 @@ func marshalRego( allowEnvironmentVariableDropping bool, allowUnencryptedScratch bool, allowCapabilityDropping bool, + allowLogProviderDropping bool, ) (string, error) { if allowAll { if len(containers) > 0 { @@ -217,6 +225,7 @@ func marshalRego( allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapabilityDropping, + allowLogProviderDropping, ) if err != nil { return "", err @@ -251,6 +260,7 @@ func MarshalPolicy( allowEnvironmentVariableDropping bool, allowUnencryptedScratch bool, allowCapbilitiesDropping bool, + allowLogProviderDropping bool, ) (string, error) { if marshaller == "" { marshaller = defaultMarshaller @@ -272,6 +282,7 @@ func MarshalPolicy( allowEnvironmentVariableDropping, allowUnencryptedScratch, allowCapbilitiesDropping, + allowLogProviderDropping, ) } } @@ -596,6 +607,7 @@ func (p securityPolicyInternal) marshalRego() string { writeLine(builder, "allow_environment_variable_dropping := %t", p.AllowEnvironmentVariableDropping) writeLine(builder, "allow_unencrypted_scratch := %t", p.AllowUnencryptedScratch) writeLine(builder, "allow_capability_dropping := %t", p.AllowCapabilityDropping) + writeLine(builder, "allow_log_provider_dropping := %t", p.AllowLogProviderDropping) result := strings.Replace(policyRegoTemplate, "@@OBJECTS@@", builder.String(), 1) result = strings.Replace(result, "@@API_VERSION@@", apiVersion, 1) result = strings.Replace(result, "@@FRAMEWORK_VERSION@@", frameworkVersion, 1) @@ -621,6 +633,7 @@ func (p securityPolicyWindowsInternal) marshalWindowsRego() string { writeLine(builder, "allow_environment_variable_dropping := %t", p.AllowEnvironmentVariableDropping) writeLine(builder, "allow_unencrypted_scratch := %t", p.AllowUnencryptedScratch) writeLine(builder, "allow_capability_dropping := %t", p.AllowCapabilityDropping) + writeLine(builder, "allow_log_provider_dropping := %t", p.AllowLogProviderDropping) result := strings.Replace(policyRegoTemplate, "@@OBJECTS@@", builder.String(), 1) result = strings.Replace(result, "@@API_VERSION@@", apiVersion, 1) result = strings.Replace(result, "@@FRAMEWORK_VERSION@@", frameworkVersion, 1) diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 0c2a98e998..bb13b27237 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -131,6 +131,14 @@ type SecurityPolicyEnforcer interface { GetUserInfo(spec *oci.Process, rootPath string) (IDName, []IDName, string, error) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string, mountedCim []string) (err error) EnforceRegistryChangesPolicy(ctx context.Context, containerID string, registryValues interface{}) error + // EnforceLogProviderPolicy validates a batch of requested ETW provider + // names against the policy's allowed_log_providers list. It returns the + // subset of provider names that the caller should forward to the inbox + // GCS, plus any policy error. When the policy has + // allow_log_provider_dropping := true, providers not on the allow-list are + // silently dropped from the returned slice; otherwise the whole call is + // denied (returning a non-nil error) if any provider is not allowed. + EnforceLogProviderPolicy(ctx context.Context, providerNames []string) ([]string, error) WithMetadataRollback(fn func() error) error } @@ -328,6 +336,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context.C return nil } +func (OpenDoorSecurityPolicyEnforcer) EnforceLogProviderPolicy(_ context.Context, providerNames []string) ([]string, error) { + return providerNames, nil +} + func (OpenDoorSecurityPolicyEnforcer) WithMetadataRollback(fn func() error) error { return fn() } @@ -461,6 +473,10 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context return errors.New("registry changes are denied by policy") } +func (ClosedDoorSecurityPolicyEnforcer) EnforceLogProviderPolicy(context.Context, []string) ([]string, error) { + return nil, errors.New("log provider is denied by policy") +} + func (ClosedDoorSecurityPolicyEnforcer) WithMetadataRollback(fn func() error) error { return fn() } diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 5e196ebd9a..b2ed0f5d1b 100644 --- a/pkg/securitypolicy/securitypolicyenforcer_rego.go +++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go @@ -583,6 +583,38 @@ func getEnvsToKeep(envList []string, results rpi.RegoQueryResult) ([]string, err return keepSet.toArray(), nil } +// getProvidersToKeep extracts the "providers_to_keep" field from the rego +// query results for log_provider enforcement and intersects it (defensively) +// with the originally requested providerNames. Mirrors getEnvsToKeep. +// +// When the policy is older / does not return providers_to_keep, the entire +// requested list is kept (analogous to env_list). +func getProvidersToKeep(providerNames []string, results rpi.RegoQueryResult) ([]string, error) { + value, err := results.Value("providers_to_keep") + if err != nil || value == nil { + // policy did not return 'providers_to_keep'. Interpret as + // "proceed with provided providers". + return providerNames, nil + } + + providersAsInterfaces, ok := value.([]interface{}) + if !ok { + return nil, fmt.Errorf("policy returned incorrect type for 'providers_to_keep', expected []interface{}, received %T", value) + } + + keepSet := make(stringSet) + for _, providerAsInterface := range providersAsInterfaces { + if provider, ok := providerAsInterface.(string); ok { + keepSet.add(provider) + } else { + return nil, fmt.Errorf("members of providers_to_keep from policy must be strings, received %T", providerAsInterface) + } + } + + keepSet = keepSet.intersect(toStringSet(providerNames)) + return keepSet.toArray(), nil +} + func getCapsToKeep(capsList *oci.LinuxCapabilities, results rpi.RegoQueryResult) (*oci.LinuxCapabilities, error) { value, err := results.Value("caps_list") if err != nil || value == nil { @@ -1193,6 +1225,17 @@ func (policy *regoEnforcer) EnforceRegistryChangesPolicy(ctx context.Context, co return err } +func (policy *regoEnforcer) EnforceLogProviderPolicy(ctx context.Context, providerNames []string) ([]string, error) { + input := inputData{ + "providers": providerNames, + } + results, err := policy.enforce(ctx, "log_provider", input) + if err != nil { + return nil, err + } + return getProvidersToKeep(providerNames, results) +} + func (policy *regoEnforcer) GetUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) { return GetAllUserInfo(process, rootPath) } diff --git a/test/pkg/securitypolicy/policy.go b/test/pkg/securitypolicy/policy.go index eeb93f67c8..bdf748bac6 100644 --- a/test/pkg/securitypolicy/policy.go +++ b/test/pkg/securitypolicy/policy.go @@ -64,6 +64,7 @@ func PolicyWithOpts(tb testing.TB, policyType string, pOpts ...securitypolicy.Po config.AllowEnvironmentVariableDropping, config.AllowUnencryptedScratch, config.AllowCapabilityDropping, + config.AllowLogProviderDropping, ) if err != nil { tb.Fatal(err)