Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
import asyncio
import copy
import datetime
import functools
import inspect
import logging
import os
import threading
from typing import NamedTuple, Optional, TYPE_CHECKING

from google.auth import _helpers
from google.auth import environment_vars

if TYPE_CHECKING:
import google.auth.credentials
Expand All @@ -34,25 +31,6 @@
_LOGGER = logging.getLogger(__name__)


@functools.lru_cache()
def is_regional_access_boundary_enabled():
"""Checks if Regional Access Boundary is enabled via environment variable.

The environment variable is interpreted as a boolean with the following
(case-insensitive) rules:
- "true", "1" are considered true.
- Any other value (or unset) is considered false.

Returns:
bool: True if Regional Access Boundary is enabled, False otherwise.
"""
value = os.environ.get(environment_vars.GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED)
if value is None:
return False

return value.lower() in ("true", "1")


# The default lifetime for a cached Regional Access Boundary.
DEFAULT_REGIONAL_ACCESS_BOUNDARY_TTL = datetime.timedelta(hours=6)

Expand Down
8 changes: 3 additions & 5 deletions packages/google-auth/google/auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,11 +841,9 @@ def from_info(cls, info, **kwargs):
Raises:
ValueError: For invalid parameters.
"""
aws_security_credentials_supplier = info.get(
"aws_security_credentials_supplier"
)
kwargs.update(
{"aws_security_credentials_supplier": aws_security_credentials_supplier}
kwargs.setdefault(
"aws_security_credentials_supplier",
info.get("aws_security_credentials_supplier"),
)
return super(Credentials, cls).from_info(info, **kwargs)

Expand Down
17 changes: 8 additions & 9 deletions packages/google-auth/google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,13 @@ def _is_regional_endpoint(self, url):
try:
# Do not perform a lookup if the request is for a regional endpoint.
hostname = urlparse(url).hostname
if hostname and (
hostname.endswith(".rep.googleapis.com")
or hostname.endswith(".rep.sandbox.googleapis.com")
if hostname and hostname.endswith(
Comment thread
lsirac marked this conversation as resolved.
(
".rep.googleapis.com",
".rep.sandbox.googleapis.com",
".rep.mtls.googleapis.com",
".rep.mtls.sandbox.googleapis.com",
)
):
return True
except (ValueError, TypeError, AttributeError):
Expand Down Expand Up @@ -484,16 +488,11 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url):
def _is_regional_access_boundary_lookup_required(self):
"""Checks if a Regional Access Boundary lookup is required.

A lookup is required if the feature is enabled via an environment
variable and the universe domain is supported.
A lookup is required if the universe domain is supported.

Returns:
bool: True if a Regional Access Boundary lookup is required, False otherwise.
"""
# Check if the feature is enabled.
if not _regional_access_boundary_utils.is_regional_access_boundary_enabled():
return False

# Skip for non-default universe domains.
if self.universe_domain != DEFAULT_UNIVERSE_DOMAIN:
return False
Expand Down
6 changes: 5 additions & 1 deletion packages/google-auth/google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,13 @@
AWS_REGION = "AWS_REGION"
AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION"


GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED = "GOOGLE_AUTH_TRUST_BOUNDARY_ENABLED"
"""Environment variable controlling whether to enable trust boundary feature.
The default value is false. Users have to explicitly set this value to true."""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a publicly accessible value. I don't think it's safe to remove, in case someone imported it. Can we just leave it with a comment that it's deprecated?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!


.. deprecated::
This environment variable is deprecated and no longer has any effect.
"""

GOOGLE_API_CERTIFICATE_CONFIG = "GOOGLE_API_CERTIFICATE_CONFIG"
"""Environment variable defining the location of Google API certificate config
Expand Down
33 changes: 30 additions & 3 deletions packages/google-auth/google/auth/external_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import json
import logging
import re
import threading
from typing import Optional, TYPE_CHECKING


Expand Down Expand Up @@ -200,6 +201,7 @@ def __init__(
self._metrics_options = self._create_default_metrics_options()

self._impersonated_credentials = None
self._impersonation_lock = threading.Lock()
self._project_id = None
self._supplier_context = SupplierContext(
self._subject_token_type, self._audience
Expand All @@ -213,6 +215,15 @@ def __init__(
"credentials"
)

def __getstate__(self):
state = self.__dict__.copy()
state.pop("_impersonation_lock", None)
return state

def __setstate__(self, state):
super().__setstate__(state)
self._impersonation_lock = threading.Lock()

@property
def info(self):
"""Generates the dictionary representation of the current credentials.
Expand Down Expand Up @@ -444,6 +455,17 @@ def _maybe_start_regional_access_boundary_refresh(self, request, url):
HTTP requests.
url (str): The URL of the request.
"""
if self._should_initialize_impersonated_credentials():
with self._impersonation_lock:
if self._impersonated_credentials is None:
impersonated = self._initialize_impersonated_credentials()
if getattr(self, "token", None):
impersonated.token = self.token
if getattr(self, "expiry", None):
impersonated.expiry = self.expiry
self._impersonated_credentials = impersonated
self._rab_manager = impersonated._rab_manager

if getattr(self, "_impersonated_credentials", None):
self._impersonated_credentials._maybe_start_regional_access_boundary_refresh(
request, url
Expand All @@ -462,7 +484,11 @@ def _perform_refresh_token(self, request, cert_fingerprint=None):
)

if self._should_initialize_impersonated_credentials():
self._impersonated_credentials = self._initialize_impersonated_credentials()
with self._impersonation_lock:
if self._impersonated_credentials is None:
self._impersonated_credentials = (
self._initialize_impersonated_credentials()
)

if self._impersonated_credentials:
self._impersonated_credentials.refresh(request)
Expand Down Expand Up @@ -581,9 +607,10 @@ def with_universe_domain(self, universe_domain):
return cred

def _should_initialize_impersonated_credentials(self):
"""Determines if the underlying Service Account credential should be initialized."""
return (
self._service_account_impersonation_url is not None
and self._impersonated_credentials is None
getattr(self, "_service_account_impersonation_url", None) is not None
and getattr(self, "_impersonated_credentials", None) is None
)

def _initialize_impersonated_credentials(self):
Expand Down
3 changes: 1 addition & 2 deletions packages/google-auth/google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ def from_info(cls, info, **kwargs):
Raises:
ValueError: For invalid parameters.
"""
subject_token_supplier = info.get("subject_token_supplier")
kwargs.update({"subject_token_supplier": subject_token_supplier})
kwargs.setdefault("subject_token_supplier", info.get("subject_token_supplier"))
return super(Credentials, cls).from_info(info, **kwargs)

@classmethod
Expand Down
50 changes: 35 additions & 15 deletions packages/google-auth/tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import base64
import datetime
import re
from unittest import mock

import pytest # type: ignore
Expand Down Expand Up @@ -206,6 +207,7 @@ def test_before_request_refreshes(self, get):
"access_token": "token",
"expires_in": 500,
},
"googleapis.com",
]

# Credentials should start as invalid
Expand Down Expand Up @@ -252,7 +254,11 @@ def test_with_universe_domain(self):
assert creds.universe_domain == "universe_domain"
assert creds._universe_domain_cached

def test_token_usage_metrics(self):
@mock.patch(
"google.auth.compute_engine._metadata.get_universe_domain",
return_value="googleapis.com",
)
def test_token_usage_metrics(self, mock_get_universe_domain):
self.credentials.token = "token"
self.credentials.expiry = None

Expand Down Expand Up @@ -410,11 +416,7 @@ def test_build_regional_access_boundary_lookup_url_no_email(
url = creds._build_regional_access_boundary_lookup_url()
assert url is None

@mock.patch(
"google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled",
return_value=True,
)
def test_is_regional_access_boundary_lookup_required(self, mock_enabled):
def test_is_regional_access_boundary_lookup_required(self):
creds = self.credentials
creds._universe_domain_cached = True

Expand Down Expand Up @@ -442,15 +444,11 @@ def test_build_regional_access_boundary_lookup_url_with_invalid_email(self):
url = creds._build_regional_access_boundary_lookup_url()
assert url is None

@mock.patch(
"google.auth._regional_access_boundary_utils.is_regional_access_boundary_enabled",
return_value=True,
)
@mock.patch(
"google.auth.compute_engine._metadata.get_service_account_info", autospec=True
)
def test_regional_access_boundary_disabled_state_transitions(
self, mock_get_service_account_info, mock_enabled
self, mock_get_service_account_info
):
mock_get_service_account_info.return_value = {
"email": "spiffe://trust-domain/ns/ns/sa/sa",
Expand Down Expand Up @@ -769,6 +767,15 @@ def test_with_target_audience_integration(self):
json={},
)

# mock allowedLocations for Regional Access Boundary
responses.add(
responses.GET,
re.compile(r".*/allowedLocations$"),
status=200,
content_type="application/json",
json={"encodedLocations": "0xABC"},
)

# mock token for credentials
responses.add(
responses.GET,
Expand All @@ -787,8 +794,10 @@ def test_with_target_audience_integration(self):
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob",
re.compile(
r"https://iamcredentials\.(mtls\.)?googleapis\.com/v1/projects/-/"
r"serviceAccounts/service-account@example\.com:signBlob"
),
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down Expand Up @@ -951,12 +960,23 @@ def test_with_quota_project_integration(self):
json={},
)

# mock allowedLocations for Regional Access Boundary
responses.add(
responses.GET,
re.compile(r".*/allowedLocations$"),
status=200,
content_type="application/json",
json={"encodedLocations": "0xABC"},
)

# mock sign blob endpoint
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
responses.POST,
"https://iamcredentials.googleapis.com/v1/projects/-/"
"serviceAccounts/service-account@example.com:signBlob",
re.compile(
r"https://iamcredentials\.(mtls\.)?googleapis\.com/v1/projects/-/"
r"serviceAccounts/service-account@example\.com:signBlob"
),
status=200,
content_type="application/json",
json={"keyId": "some-key-id", "signedBlob": signature},
Expand Down
Loading
Loading