diff --git a/packages/google-auth/google/auth/_regional_access_boundary_utils.py b/packages/google-auth/google/auth/_regional_access_boundary_utils.py index c97bf8f484df..01dbe80169e4 100644 --- a/packages/google-auth/google/auth/_regional_access_boundary_utils.py +++ b/packages/google-auth/google/auth/_regional_access_boundary_utils.py @@ -27,7 +27,7 @@ from google.auth import _helpers from google.auth import environment_vars -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER import google.auth.credentials import google.auth.transport @@ -455,6 +455,61 @@ def start_refresh(self, credentials, request, rab_manager): self._worker.start() +def _prepare_async_lookup_callable(request): + """Unwraps a request callable, clones the transport, and returns the new callable. + + Args: + request: The original request callable (e.g. functools.partial or raw Request). + + Returns: + Tuple[Callable, Any, bool]: A tuple containing the new lookup callable, the + underlying request object, and a boolean indicating if it was cloned. + """ + is_partial = isinstance(request, functools.partial) + base_callable = request.func if is_partial else request + + if not hasattr(base_callable, "_clone"): + return request, base_callable, False + + cloned_callable = base_callable._clone() + is_cloned = cloned_callable is not base_callable + + if is_partial: + new_request = functools.partial( + cloned_callable, *request.args, **request.keywords + ) + else: + new_request = cloned_callable + + return new_request, cloned_callable, is_cloned + + +async def _close_cloned_request(lookup_request, is_cloned): + """Safely closes the underlying cloned request transport, if applicable. + + Args: + lookup_request (Any): The request object/transport to close. + is_cloned (bool): Whether the request was actually cloned. + """ + if not is_cloned or not hasattr(lookup_request, "close"): + return + + is_async = False + try: + maybe_coro = lookup_request.close() + if is_async := inspect.iscoroutine(maybe_coro): + await maybe_coro + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + adapter_type = " asynchronous " if is_async else " " + _LOGGER.warning( + "Failed to cleanly close cloned%srequest transport: %s", + adapter_type, + e, + exc_info=True, + ) + + class _AsyncRegionalAccessBoundaryRefreshManager(object): """Manages a task for background refreshing of the Regional Access Boundary in async flows.""" @@ -492,10 +547,18 @@ def start_refresh(self, credentials, request, rab_manager): return async def _worker(): + lookup_request = None + is_cloned = False try: - # credentials._lookup_regional_access_boundary should be async in the async creds class + ( + lookup_callable, + lookup_request, + is_cloned, + ) = _prepare_async_lookup_callable(request) regional_access_boundary_info = ( - await credentials._lookup_regional_access_boundary(request) + await credentials._lookup_regional_access_boundary( + lookup_callable + ) ) except Exception as e: if _helpers.is_logging_enabled(_LOGGER): @@ -505,6 +568,8 @@ async def _worker(): exc_info=True, ) regional_access_boundary_info = None + finally: + await _close_cloned_request(lookup_request, is_cloned) rab_manager.process_regional_access_boundary_info( regional_access_boundary_info diff --git a/packages/google-auth/google/auth/aio/transport/__init__.py b/packages/google-auth/google/auth/aio/transport/__init__.py index 166a3be50914..343711272a95 100644 --- a/packages/google-auth/google/auth/aio/transport/__init__.py +++ b/packages/google-auth/google/auth/aio/transport/__init__.py @@ -142,3 +142,13 @@ async def close(self) -> None: Close the underlying session. """ raise NotImplementedError("close must be implemented.") + + def _clone(self) -> "Request": + """Creates a copy of this request adapter. + + The base implementation returns `self` (an identical shared instance). + Transport adapters that maintain internal connection pools or stateful + sessions must override this method to return an independent, detached + adapter instance. + """ + return self diff --git a/packages/google-auth/google/auth/aio/transport/aiohttp.py b/packages/google-auth/google/auth/aio/transport/aiohttp.py index 642d15927d0f..d2508489f0f4 100644 --- a/packages/google-auth/google/auth/aio/transport/aiohttp.py +++ b/packages/google-auth/google/auth/aio/transport/aiohttp.py @@ -36,7 +36,7 @@ else: try: from aiohttp import ClientTimeout - except (ImportError, AttributeError): + except (ImportError, AttributeError): # pragma: NO COVER ClientTimeout = None _LOGGER = logging.getLogger(__name__) @@ -203,3 +203,74 @@ async def close(self) -> None: if not self._closed and self._session: await self._session.close() self._closed = True + + def _clone(self) -> "Request": + """Creates an independent copy of this request adapter. + + Returns: + google.auth.aio.transport.aiohttp.Request: A request adapter copy + running a new aiohttp.ClientSession with identical connection, + proxy, and session configurations. + """ + if self._closed: + raise exceptions.TransportError("Cannot clone a closed transport.") + + if not self._session: + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=True, + ) + return Request(session=new_session) + + session_kwargs: dict = { + "auto_decompress": False, + "trust_env": getattr(self._session, "_trust_env", True), + } + + # Copy underlying connection pool settings (SSL context, IP bindings, limits). + orig_connector = getattr(self._session, "_connector", None) + if orig_connector and not orig_connector.closed: + if isinstance(orig_connector, aiohttp.TCPConnector): + # We explicitly do not copy the resolver. The connector + # owns the resolver, and closing the cloned session would + # close the shared resolver, breaking the original session. + session_kwargs["connector"] = aiohttp.TCPConnector( + ssl=getattr(orig_connector, "_ssl", None), # type: ignore + limit=getattr(orig_connector, "_limit", 100), + limit_per_host=getattr(orig_connector, "_limit_per_host", 0), + force_close=getattr(orig_connector, "_force_close", False), + local_addr=getattr(orig_connector, "_local_addr", None), + ) + elif getattr(aiohttp, "UnixConnector", None) and isinstance( + orig_connector, getattr(aiohttp, "UnixConnector") + ): + path = getattr(orig_connector, "_path", None) + if path: + session_kwargs["connector"] = aiohttp.UnixConnector( + path=path, + limit=getattr(orig_connector, "_limit", 100), + force_close=getattr(orig_connector, "_force_close", False), + ) + else: + raise exceptions.TransportError( + f"Unsupported connector type for cloning: {type(orig_connector)}" + ) + + # Preserve distributed tracing configurations. + trace_configs = getattr(self._session, "_trace_configs", None) + if trace_configs: + session_kwargs["trace_configs"] = list(trace_configs) + + # Copy session-level defaults (headers, cookies, auth, timeout). + for attr_name, kwarg_name in [ + ("_default_headers", "headers"), + ("_cookie_jar", "cookie_jar"), + ("_default_auth", "auth"), + ("_timeout", "timeout"), + ("_json_serialize", "json_serialize"), + ]: + val = getattr(self._session, attr_name, None) + if val is not None: + session_kwargs[kwarg_name] = val + + return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore diff --git a/packages/google-auth/google/auth/transport/_aiohttp_requests.py b/packages/google-auth/google/auth/transport/_aiohttp_requests.py index e8321965e0db..f82131a12e7f 100644 --- a/packages/google-auth/google/auth/transport/_aiohttp_requests.py +++ b/packages/google-auth/google/auth/transport/_aiohttp_requests.py @@ -203,6 +203,83 @@ async def __call__( new_exc = exceptions.TransportError(caught_exc) raise new_exc from caught_exc + def _clone(self): + """Create an independent detached copy of this request adapter. + + Returns: + google.auth.transport._aiohttp_requests.Request: An independent request adapter + running an isolated aiohttp.ClientSession with identical environment proxy and + observability configurations. + """ + if getattr(self, "_closed", False): + raise exceptions.TransportError("Cannot clone a closed transport.") + + if not self.session: + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=True, + ) + return Request(session=new_session) + + session_kwargs: dict = { + "auto_decompress": False, + "trust_env": getattr(self.session, "_trust_env", True), + } + + # Copy underlying connection pool settings (SSL context, IP bindings, limits). + orig_connector = getattr(self.session, "_connector", None) + if orig_connector and not getattr(orig_connector, "closed", True): + if isinstance(orig_connector, aiohttp.TCPConnector): + # We explicitly do not copy the resolver. The connector + # owns the resolver, and closing the cloned session would + # close the shared resolver, breaking the original session. + session_kwargs["connector"] = aiohttp.TCPConnector( + ssl=getattr(orig_connector, "_ssl", None), # type: ignore + limit=getattr(orig_connector, "_limit", 100), + limit_per_host=getattr(orig_connector, "_limit_per_host", 0), + force_close=getattr(orig_connector, "_force_close", False), + local_addr=getattr(orig_connector, "_local_addr", None), + ) + elif getattr(aiohttp, "UnixConnector", None) and isinstance( + orig_connector, getattr(aiohttp, "UnixConnector") + ): + path = getattr(orig_connector, "_path", None) + if path: + session_kwargs["connector"] = aiohttp.UnixConnector( + path=path, + limit=getattr(orig_connector, "_limit", 100), + force_close=getattr(orig_connector, "_force_close", False), + ) + else: + raise exceptions.TransportError( + f"Unsupported connector type for cloning: {type(orig_connector)}" + ) + + # Preserve distributed tracing configurations. + trace_configs = getattr(self.session, "_trace_configs", None) + if trace_configs: + session_kwargs["trace_configs"] = list(trace_configs) + + # Copy session-level defaults (headers, cookies, auth, timeout). + for attr_name, kwarg_name in [ + ("_default_headers", "headers"), + ("_cookie_jar", "cookie_jar"), + ("_default_auth", "auth"), + ("_timeout", "timeout"), + ("_json_serialize", "json_serialize"), + ]: + val = getattr(self.session, attr_name, None) + if val is not None: + session_kwargs[kwarg_name] = val + + return Request(session=aiohttp.ClientSession(**session_kwargs)) # type: ignore + + async def close(self): + """Cleanly release the underlying aiohttp ClientSession resources.""" + if not getattr(self, "_closed", False) and self.session: + await self.session.close() + self._closed = True + class AuthorizedSession(aiohttp.ClientSession): """This is an async implementation of the Authorized Session class. We utilize an diff --git a/packages/google-auth/tests/test__regional_access_boundary_utils.py b/packages/google-auth/tests/test__regional_access_boundary_utils.py index c612b60b8ed2..e2f69c2cea5b 100644 --- a/packages/google-auth/tests/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests/test__regional_access_boundary_utils.py @@ -678,6 +678,7 @@ async def test_async_refresh_manager_session_closed_ignored(self): ) request = mock.Mock() + request._clone.return_value = request rab_manager = mock.Mock() manager = ( @@ -694,6 +695,120 @@ async def test_async_refresh_manager_session_closed_ignored(self): credentials._lookup_regional_access_boundary.assert_called_once_with(request) rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + @pytest.mark.asyncio + async def test_start_refresh_async_clones_request_and_unwraps_partial(self): + import functools + + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.return_value = { + "encodedLocations": "0xA30" + } + + mock_request = mock.Mock() + mock_cloned_request = mock.Mock() + mock_request._clone.return_value = mock_cloned_request + mock_cloned_request.close = mock.AsyncMock() + + # Wrap in a functools.partial to simulate AuthorizedSession.request() timeouts + partial_request = functools.partial(mock_request, timeout=180) + + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + manager.start_refresh(credentials, partial_request, rab_manager) + + await manager._worker_task + + # Verify that actual_request._clone() was called + mock_request._clone.assert_called_once() + + # Verify that the lookup ran on a re-wrapped partial of the cloned request + called_arg = credentials._lookup_regional_access_boundary.call_args[0][0] + assert isinstance(called_arg, functools.partial) + assert called_arg.func is mock_cloned_request + assert called_arg.keywords == {"timeout": 180} + + # Verify that the cloned request was closed cleanly in the finally block + mock_cloned_request.close.assert_awaited_once() + rab_manager.process_regional_access_boundary_info.assert_called_once_with( + {"encodedLocations": "0xA30"} + ) + + @pytest.mark.asyncio + async def test_start_refresh_suppresses_request_clone_exception(self): + from google.auth import exceptions + + credentials = mock.AsyncMock() + + request = mock.Mock() + request._clone.side_effect = exceptions.TransportError( + "Cannot clone a closed transport." + ) + + rab_manager = mock.Mock() + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + manager.start_refresh(credentials, request, rab_manager) + await manager._worker_task + + credentials._lookup_regional_access_boundary.assert_not_called() + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + + @pytest.mark.asyncio + async def test_start_refresh_async_mimics_ephemeral_session_closed_bug(self): + # Specifically mimics the real-world race condition where a fast foreground main call + # pulls the rug out from under the background worker when using an un-cloned session. + import asyncio + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + worker_started_event = asyncio.Event() + foreground_closed_event = asyncio.Event() + + class EphemeralRequest: + def __init__(self): + self.closed = False + + async def __call__(self, *args, **kwargs): + worker_started_event.set() + await foreground_closed_event.wait() + if self.closed: + raise RuntimeError("Session is closed") + return "success" + + ephemeral_req = EphemeralRequest() + + credentials = mock.AsyncMock() + + async def mock_lookup(req): + return await req() + + credentials._lookup_regional_access_boundary.side_effect = mock_lookup + + rab_manager = mock.Mock() + + # Start the background refresh worker + manager.start_refresh(credentials, ephemeral_req, rab_manager) + + # Wait until the background worker has actually started its speculative request + await worker_started_event.wait() + + # Simulate fast foreground primary call closing the session + ephemeral_req.closed = True + foreground_closed_event.set() + + # Await the background worker task to settle + await manager._worker_task + + # Verify that the background worker hit the "Session is closed" error and failed open cleanly + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + def test_get_service_account_rab_endpoint(monkeypatch): from google.auth.transport import _mtls_helper @@ -761,3 +876,58 @@ def test_get_workload_identity_pool_rab_endpoint(monkeypatch): url == "https://iamcredentials.mtls.googleapis.com/v1/projects/PROJECT_NUM/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" ) + + +def test_sync_refresh_manager_pickle(): + import pickle + + manager = _regional_access_boundary_utils._RegionalAccessBoundaryRefreshManager() + manager._worker = mock.Mock() + + dumped = pickle.dumps(manager) + loaded = pickle.loads(dumped) + + assert loaded._lock is not None + assert loaded._worker is None + + +def test_manager_eq_different_type(): + manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() + assert manager != "not a manager" + + +def test_set_initial_regional_access_boundary_empty(): + manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() + manager.set_initial_regional_access_boundary( + encoded_locations="", expiry=datetime.datetime.now() + ) + assert manager._data.encoded_locations == "" + assert manager._data.expiry is None + + +def test_set_initial_regional_access_boundary_with_value(): + manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() + expiry = datetime.datetime.now() + manager.set_initial_regional_access_boundary( + encoded_locations="us-east1", expiry=expiry + ) + assert manager._data.encoded_locations == "us-east1" + assert manager._data.expiry == expiry + + +def test_sync_refresh_manager_start_refresh_executes(): + manager = _regional_access_boundary_utils._RegionalAccessBoundaryRefreshManager() + creds = mock.Mock() + request = mock.Mock() + rab_manager = mock.Mock() + + with mock.patch( + "google.auth._regional_access_boundary_utils._RegionalAccessBoundaryRefreshThread" + ) as mock_thread_class: + mock_thread = mock.Mock() + mock_thread_class.return_value = mock_thread + + manager.start_refresh(creds, request, rab_manager) + + mock_thread_class.assert_called_once() + mock_thread.start.assert_called_once() diff --git a/packages/google-auth/tests/transport/aio/test_aiohttp.py b/packages/google-auth/tests/transport/aio/test_aiohttp.py index 553f35775fac..68acac6f7619 100644 --- a/packages/google-auth/tests/transport/aio/test_aiohttp.py +++ b/packages/google-auth/tests/transport/aio/test_aiohttp.py @@ -169,3 +169,152 @@ async def test_request_call_raises_transport_error_for_closed_session( exc.match("session is closed.") aiohttp_request._closed = False + + async def test_request_clone(self): + request = auth_aiohttp.Request() + cloned = request._clone() + assert cloned is not request + assert isinstance(cloned, auth_aiohttp.Request) + assert cloned._session is not request._session + await request.close() + await cloned.close() + + async def test_request_close(self): + request = auth_aiohttp.Request() + assert not getattr(request, "_closed", False) + await request.close() + assert request._closed + # Second call should be idempotent + await request.close() + assert request._closed + + async def test_request_clone_closed_session_raises(self): + request = auth_aiohttp.Request() + await request.close() + with pytest.raises(exceptions.TransportError) as exc: + request._clone() + exc.match("Cannot clone a closed transport.") + + async def test_request_clone_with_active_session(self): + import ssl + from aiohttp import BasicAuth, ClientTimeout, TCPConnector + + custom_ssl = ssl.create_default_context() + custom_connector = TCPConnector( + ssl=custom_ssl, + limit=42, + limit_per_host=12, + force_close=True, + local_addr=("127.0.0.2", 0), + ) + + mock_session = aiohttp.ClientSession( + connector=custom_connector, + headers={"x-corporate-firewall": "open"}, + cookies={"enterprise_session": "active"}, + auth=BasicAuth("admin", "secret"), + timeout=ClientTimeout(total=84.0), + trust_env=True, + trace_configs=[aiohttp.TraceConfig()], + ) + request = auth_aiohttp.Request(session=mock_session) + + cloned = request._clone() + + assert cloned is not request + assert cloned._session is not mock_session + assert cloned._session is not None + + # Verify underlying TCPConnector configuration + cloned_connector = cloned._session._connector + assert isinstance(cloned_connector, TCPConnector) + assert cloned_connector is not custom_connector + assert cloned_connector._resolver is not custom_connector._resolver + assert cloned_connector._ssl is custom_ssl + assert cloned_connector._limit == 42 + assert cloned_connector._limit_per_host == 12 + assert cloned_connector._force_close is True + assert cloned_connector._local_addr == ("127.0.0.2", 0) + + # Verify session-level configuration + assert cloned._session._trust_env is True + assert len(cloned._session._trace_configs) == 1 + assert cloned._session._default_headers == {"x-corporate-firewall": "open"} + assert cloned._session._cookie_jar is mock_session._cookie_jar + assert cloned._session._default_auth == mock_session._default_auth + assert cloned._session._timeout == ClientTimeout(total=84.0) + + await request.close() + await cloned.close() + + async def test_request_clone_unix_socket(self): + try: + from aiohttp import UnixConnector + except ImportError: + return # Windows or environment without Unix Domain Sockets + + connector = UnixConnector(path="/var/run/enterprise.sock", limit=42) + mock_session = aiohttp.ClientSession(connector=connector) + request = auth_aiohttp.Request(session=mock_session) + + cloned = request._clone() + + assert cloned._session is not None + cloned_connector = cloned._session._connector + assert isinstance(cloned_connector, UnixConnector) + assert cloned_connector._path == "/var/run/enterprise.sock" + assert cloned_connector._limit == 42 + + await request.close() + await cloned.close() + + async def test_request_call_raises_timeout_error_int(self, aiohttp_request): + with aioresponses() as m: + m.get("http://example.com", exception=asyncio.TimeoutError) + with pytest.raises(exceptions.TimeoutError) as exc: + await aiohttp_request("http://example.com", timeout=120) + exc.match("Request timed out after 120 seconds.") + + async def test_request_clone_with_closed_connector(self): + session = aiohttp.ClientSession() + request = auth_aiohttp.Request(session=session) + await session.close() + + cloned = request._clone() + assert cloned is not request + assert cloned._session is not None + await request.close() + await cloned.close() + + async def test_request_clone_with_custom_connector(self): + session = aiohttp.ClientSession() + custom_connector = AsyncMock() + custom_connector.closed = False + custom_connector.close = AsyncMock() + session._connector = custom_connector + + request = auth_aiohttp.Request(session=session) + with pytest.raises( + exceptions.TransportError, match="Unsupported connector type for cloning" + ): + request._clone() + await request.close() + + async def test_request_clone_unix_socket_no_path(self): + try: + from aiohttp import UnixConnector + except ImportError: + return + + session = aiohttp.ClientSession() + connector = UnixConnector(path="/tmp/test.sock") + connector._path = None + session._connector = connector + + request = auth_aiohttp.Request(session=session) + cloned = request._clone() + assert cloned is not request + assert cloned._session is not None + assert cloned._session._connector is not connector + await request.close() + await cloned.close() diff --git a/packages/google-auth/tests/transport/aio/test_sessions.py b/packages/google-auth/tests/transport/aio/test_sessions.py index 9780b8e2a1d2..58643c653ca2 100644 --- a/packages/google-auth/tests/transport/aio/test_sessions.py +++ b/packages/google-auth/tests/transport/aio/test_sessions.py @@ -334,3 +334,9 @@ async def test_http_delete_method_success(self): response = await authed_session.delete(self.TEST_URL) assert await response.read() == expected_payload response = await authed_session.close() + + +def test_mock_request_clone(): + request = MockRequest() + cloned = request._clone() + assert cloned is request diff --git a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py index 268ee37261c8..121cc9595b6b 100644 --- a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py @@ -28,6 +28,7 @@ async def test_async_refresh_manager_start_refresh(): } request = mock.Mock() + request._clone.return_value = request rab_manager = mock.Mock() manager = ( @@ -82,3 +83,161 @@ async def controlled_lookup(*args, **kwargs): # Verify that the second refresh request was ignored and only one lookup occurred. assert credentials._lookup_regional_access_boundary.call_count == 1 + + +def test_prepare_async_lookup_callable_no_clone(): + request = mock.Mock(spec=[]) # explicitly no _clone + ( + new_request, + cloned, + is_cloned, + ) = _regional_access_boundary_utils._prepare_async_lookup_callable(request) + assert new_request is request + assert cloned is request + assert is_cloned is False + + +def test_prepare_async_lookup_callable_with_clone(): + request = mock.Mock() + cloned_req = mock.Mock() + request._clone.return_value = cloned_req + + ( + new_request, + cloned, + is_cloned, + ) = _regional_access_boundary_utils._prepare_async_lookup_callable(request) + assert new_request is cloned_req + assert cloned is cloned_req + assert is_cloned is True + + +def test_prepare_async_lookup_callable_partial(): + import functools + + request = mock.Mock() + cloned_req = mock.Mock() + request._clone.return_value = cloned_req + + partial_req = functools.partial(request, 1, a=2) + ( + new_request, + cloned, + is_cloned, + ) = _regional_access_boundary_utils._prepare_async_lookup_callable(partial_req) + + assert isinstance(new_request, functools.partial) + assert new_request.func is cloned_req + assert new_request.args == (1,) + assert new_request.keywords == {"a": 2} + assert cloned is cloned_req + assert is_cloned is True + + +@pytest.mark.asyncio +async def test_close_cloned_request_not_cloned(): + request = mock.Mock() + await _regional_access_boundary_utils._close_cloned_request( + request, is_cloned=False + ) + request.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_close_cloned_request_sync(): + request = mock.Mock() + await _regional_access_boundary_utils._close_cloned_request(request, is_cloned=True) + request.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_close_cloned_request_async(): + request = mock.Mock() + request.close = mock.AsyncMock() + await _regional_access_boundary_utils._close_cloned_request(request, is_cloned=True) + request.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_close_cloned_request_async_exception(): + request = mock.Mock() + request.close = mock.AsyncMock(side_effect=Exception("close error")) + # Should swallow the exception and not raise + await _regional_access_boundary_utils._close_cloned_request(request, is_cloned=True) + request.close.assert_awaited_once() + + +def test_async_refresh_manager_pickle(): + import pickle + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + manager._worker_task = mock.Mock() + + dumped = pickle.dumps(manager) + loaded = pickle.loads(dumped) + + assert loaded._lock is not None + assert loaded._worker_task is None + + +@pytest.mark.asyncio +async def test_async_worker_exception_logging_enabled(monkeypatch): + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.side_effect = Exception("lookup fail") + + request = mock.Mock() + request._clone.return_value = request + rab_manager = mock.Mock() + + # Force is_logging_enabled to return True + monkeypatch.setattr( + _regional_access_boundary_utils._helpers, + "is_logging_enabled", + lambda logger: True, + ) + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + with mock.patch.object( + _regional_access_boundary_utils._LOGGER, "warning" + ) as mock_warning: + manager.start_refresh(credentials, request, rab_manager) + await manager._worker_task + + mock_warning.assert_called_once() + assert "lookup raised an exception" in mock_warning.call_args[0][0] + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + + +@pytest.mark.asyncio +async def test_async_worker_exception_logging_disabled(monkeypatch): + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.side_effect = Exception("lookup fail") + + request = mock.Mock() + request._clone.return_value = request + rab_manager = mock.Mock() + + # Force is_logging_enabled to return False + monkeypatch.setattr( + _regional_access_boundary_utils._helpers, + "is_logging_enabled", + lambda logger: False, + ) + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + with mock.patch.object( + _regional_access_boundary_utils._LOGGER, "warning" + ) as mock_warning: + manager.start_refresh(credentials, request, rab_manager) + await manager._worker_task + + mock_warning.assert_not_called() + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) diff --git a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py index d6a24da2e302..b0ce3aac8e51 100644 --- a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py +++ b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py @@ -128,6 +128,204 @@ def test_timeout(self): request = aiohttp_requests.Request(http) request(url="http://example.com", method="GET", timeout=5) + @pytest.mark.asyncio + async def test__clone(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http._connector = mock.Mock(spec=aiohttp.TCPConnector) + http._connector.closed = False + http._connector._ssl = mock.sentinel.ssl + http._connector._limit = 50 + http._connector._limit_per_host = 10 + http._connector._force_close = True + http._connector._resolver = mock.sentinel.resolver + http._connector._local_addr = mock.sentinel.local_addr + + http._trust_env = False + http._trace_configs = [mock.sentinel.trace_config] + http._default_headers = {"test": "header"} + http._cookie_jar = mock.sentinel.cookie_jar + http._default_auth = mock.sentinel.auth + http._timeout = mock.sentinel.timeout + http._json_serialize = mock.sentinel.json_serialize + + request = aiohttp_requests.Request(http) + with mock.patch( + "aiohttp.ClientSession", autospec=True + ) as session_mock, mock.patch.object( + aiohttp.TCPConnector, "__init__", autospec=True, return_value=None + ) as connector_init_mock: + session_mock.return_value._auto_decompress = False + cloned = request._clone() + + assert isinstance(cloned, aiohttp_requests.Request) + assert cloned is not request + + connector_init_mock.assert_called_once_with( + mock.ANY, + ssl=mock.sentinel.ssl, + limit=50, + limit_per_host=10, + force_close=True, + local_addr=mock.sentinel.local_addr, + ) + + session_mock.assert_called_once_with( + connector=mock.ANY, + auto_decompress=False, + trust_env=False, + trace_configs=[mock.sentinel.trace_config], + headers={"test": "header"}, + cookie_jar=mock.sentinel.cookie_jar, + auth=mock.sentinel.auth, + timeout=mock.sentinel.timeout, + json_serialize=mock.sentinel.json_serialize, + ) + assert isinstance(session_mock.call_args[1]["connector"], aiohttp.TCPConnector) + + @pytest.mark.asyncio + async def test__clone_closed(self): + request = aiohttp_requests.Request() + request._closed = True + with pytest.raises( + google.auth.exceptions.TransportError, + match="Cannot clone a closed transport.", + ): + request._clone() + + @pytest.mark.asyncio + async def test__clone_custom_connector(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http._connector = mock.Mock() + http._connector.closed = False + request = aiohttp_requests.Request(http) + with pytest.raises( + google.auth.exceptions.TransportError, + match="Unsupported connector type for cloning", + ): + request._clone() + + @pytest.mark.asyncio + async def test_close(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http.close = mock.AsyncMock() + request = aiohttp_requests.Request(http) + + await request.close() + assert request._closed is True + http.close.assert_awaited_once() + + # Check idempotency + await request.close() + http.close.assert_awaited_once() # Still only called 1 time + + @pytest.mark.asyncio + async def test__clone_no_session(self): + request = aiohttp_requests.Request() + cloned = request._clone() + assert isinstance(cloned, aiohttp_requests.Request) + assert cloned is not request + assert cloned.session is not None + await cloned.close() + + @pytest.mark.asyncio + async def test__clone_closed_connector(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http._connector = mock.Mock() + http._connector.closed = True + http._trust_env = True + http._trace_configs = None + http._default_headers = None + http._cookie_jar = None + http._default_auth = None + http._timeout = None + http._json_serialize = None + + request = aiohttp_requests.Request(http) + with mock.patch("aiohttp.ClientSession", autospec=True) as session_mock: + session_mock.return_value._auto_decompress = False + cloned = request._clone() + + assert isinstance(cloned, aiohttp_requests.Request) + assert cloned is not request + + @pytest.mark.asyncio + async def test__clone_unix_socket_no_path(self): + try: + from aiohttp import UnixConnector + except ImportError: + return + + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http._connector = mock.Mock(spec=UnixConnector) + http._connector.closed = False + http._connector._path = None + http._trust_env = True + http._trace_configs = None + http._default_headers = None + http._cookie_jar = None + http._default_auth = None + http._timeout = None + http._json_serialize = None + + request = aiohttp_requests.Request(http) + with mock.patch("aiohttp.ClientSession", autospec=True) as session_mock: + session_mock.return_value._auto_decompress = False + cloned = request._clone() + + assert isinstance(cloned, aiohttp_requests.Request) + assert cloned is not request + + @pytest.mark.asyncio + async def test__clone_unix_socket_with_path(self): + try: + from aiohttp import UnixConnector + except ImportError: + return + + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + http._connector = mock.Mock(spec=UnixConnector) + http._connector.closed = False + http._connector._path = "/tmp/test.sock" + http._connector._limit = 42 + http._connector._force_close = True + http._trust_env = True + http._trace_configs = None + http._default_headers = None + http._cookie_jar = None + http._default_auth = None + http._timeout = None + http._json_serialize = None + + request = aiohttp_requests.Request(http) + with mock.patch( + "aiohttp.ClientSession", autospec=True + ) as session_mock, mock.patch.object( + UnixConnector, "__init__", autospec=True, return_value=None + ) as connector_init_mock: + session_mock.return_value._auto_decompress = False + cloned = request._clone() + + assert isinstance(cloned, aiohttp_requests.Request) + assert cloned is not request + connector_init_mock.assert_called_once_with( + mock.ANY, + path="/tmp/test.sock", + limit=42, + force_close=True, + ) + class CredentialsStub(google.auth._credentials_async.Credentials): def __init__(self, token="token"):