diff --git a/CHANGES/12824.bugfix.rst b/CHANGES/12824.bugfix.rst new file mode 100644 index 00000000000..f8dbd169c31 --- /dev/null +++ b/CHANGES/12824.bugfix.rst @@ -0,0 +1 @@ +Fixed :class:`~aiohttp.CookieJar` dropping the host-only flag of cookies when persisted with :meth:`~aiohttp.CookieJar.save` and reloaded with :meth:`~aiohttp.CookieJar.load`, so a cookie set without a ``Domain`` attribute is again scoped to the exact host that set it after a reload; the absolute expiration deadline is now persisted as well, so a reloaded cookie keeps its original lifetime instead of being rescheduled from the load time. :meth:`~aiohttp.CookieJar.load` now replaces the jar contents rather than merging onto prior state, and loaded cookies pass through the same acceptance rules as :meth:`~aiohttp.CookieJar.update_cookies`, so a cookie for an IP-address host is dropped when loaded into a jar created without ``unsafe=True`` -- by :user:`bdraco`. diff --git a/CHANGES/12825.bugfix.rst b/CHANGES/12825.bugfix.rst new file mode 100644 index 00000000000..8bfd90bff65 --- /dev/null +++ b/CHANGES/12825.bugfix.rst @@ -0,0 +1 @@ +Scoped :class:`~aiohttp.client_middleware_digest_auth.DigestAuthMiddleware` credentials to the origin of the first request it handles, so a redirect to a different origin no longer triggers a digest response computed from the configured credentials; a challenge from another origin is only answered when that origin falls within a protection space advertised by the anchor origin through the RFC 7616 ``domain`` directive -- by :user:`bdraco`. diff --git a/CHANGES/12826.bugfix.rst b/CHANGES/12826.bugfix.rst new file mode 100644 index 00000000000..7e095615d84 --- /dev/null +++ b/CHANGES/12826.bugfix.rst @@ -0,0 +1 @@ +Fixed the C HTTP parser not enforcing ``max_line_size`` on a request target or response reason phrase that is split across multiple reads; each fragment was checked on its own, so an accumulated line could exceed the limit without raising ``LineTooLong``. The accumulated length is now checked, matching the pure-Python parser -- by :user:`bdraco`. diff --git a/CHANGES/12827.bugfix.rst b/CHANGES/12827.bugfix.rst new file mode 100644 index 00000000000..9442867d363 --- /dev/null +++ b/CHANGES/12827.bugfix.rst @@ -0,0 +1 @@ +Changed :class:`~aiohttp.TCPConnector` to reject legacy non-canonical numeric IPv4 host forms such as ``2130706433``, ``017700000001`` and ``127.1`` with :exc:`~aiohttp.InvalidUrlClientError`; only canonical dotted-quad IPv4 literals are now treated as IP address literals, while every other host is sent through the configured resolver -- by :user:`bdraco`. diff --git a/CHANGES/12828.bugfix.rst b/CHANGES/12828.bugfix.rst new file mode 100644 index 00000000000..9893577a587 --- /dev/null +++ b/CHANGES/12828.bugfix.rst @@ -0,0 +1 @@ +Fixed :meth:`~aiohttp.StreamReader.readany` and :meth:`~aiohttp.StreamReader.read_nowait` joining data fed back into the buffer during the call (when draining below the low water mark resumes reading) into a single unbounded :class:`bytes`; a call now returns only the chunks that were buffered when it started, keeping the drain of an unread auto-decompressed request body bounded by the read buffer -- by :user:`bdraco`. diff --git a/CHANGES/12829.bugfix.rst b/CHANGES/12829.bugfix.rst new file mode 100644 index 00000000000..6bb27ccc934 --- /dev/null +++ b/CHANGES/12829.bugfix.rst @@ -0,0 +1 @@ +Fixed :class:`~aiohttp.ClientSession` with ``trust_env=True`` carrying a proxy's ``Proxy-Authorization`` header across a redirect; the environment proxy and its credentials are now re-resolved for each request, so a redirect that selects a different proxy no longer reuses the previous proxy's authentication -- by :user:`bdraco`. diff --git a/CHANGES/12831.bugfix.rst b/CHANGES/12831.bugfix.rst new file mode 100644 index 00000000000..bf460ffccac --- /dev/null +++ b/CHANGES/12831.bugfix.rst @@ -0,0 +1 @@ +Fixed :meth:`aiohttp.web.Response.write_eof` skipping ``Payload.close()`` when the body write was interrupted by an error or cancellation, for example when a client disconnects mid-response; the payload close hook now runs in a ``finally`` so a :class:`~aiohttp.payload.Payload` body always releases its resources -- by :user:`bdraco`. diff --git a/CHANGES/12832.bugfix.rst b/CHANGES/12832.bugfix.rst new file mode 100644 index 00000000000..00562fecd61 --- /dev/null +++ b/CHANGES/12832.bugfix.rst @@ -0,0 +1 @@ +Fixed the pure-Python HTTP parser not enforcing ``max_line_size`` on a chunk-size line when the whole line arrived in a single read; the limit was only applied to chunk-size metadata split across reads. The complete-line case is now checked too, matching the split-line behavior -- by :user:`bdraco`. diff --git a/CHANGES/12835.bugfix.rst b/CHANGES/12835.bugfix.rst new file mode 100644 index 00000000000..84a8ae00677 --- /dev/null +++ b/CHANGES/12835.bugfix.rst @@ -0,0 +1 @@ +Included the per-request ``server_hostname`` override in the :class:`~aiohttp.TCPConnector` connection pool key, so a pooled TLS connection is no longer reused for a request that sets ``server_hostname`` to a different value -- by :user:`bdraco`. diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index e65a8d6cba3..700e9db7f2e 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -781,7 +781,7 @@ cdef int cb_on_url(cparser.llhttp_t* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data try: - if length > pyparser._max_line_size: + if len(pyparser._buf) + length > pyparser._max_line_size: status = pyparser._buf + at[:length] raise LineTooLong(status[:100] + b"...", pyparser._max_line_size) extend(pyparser._buf, at, length) @@ -796,7 +796,7 @@ cdef int cb_on_status(cparser.llhttp_t* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data try: - if length > pyparser._max_line_size: + if len(pyparser._buf) + length > pyparser._max_line_size: reason = pyparser._buf + at[:length] raise LineTooLong(reason[:100] + b"...", pyparser._max_line_size) extend(pyparser._buf, at, length) diff --git a/aiohttp/client.py b/aiohttp/client.py index 4852b5b706b..4eb4e9454e2 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -549,10 +549,11 @@ async def _request( if proxy is None: proxy = self._default_proxy + resolved_proxy_headers: CIMultiDict[str] | None if proxy is None: - proxy_headers = None + resolved_proxy_headers = None else: - proxy_headers = self._prepare_headers(proxy_headers) + resolved_proxy_headers = self._prepare_headers(proxy_headers) try: proxy = URL(proxy) except ValueError as e: @@ -654,16 +655,17 @@ async def _request( if proxy is not None: proxy_ = URL(proxy) elif self._trust_env: + # Re-resolve per iteration; drop stale env-proxy auth so + # a redirect that switches proxies can't leak credentials. + resolved_proxy_headers = None with suppress(LookupError): proxy_, env_proxy_auth = await asyncio.to_thread( get_env_proxy_for_url, url ) - if env_proxy_auth is not None and ( - proxy_headers is None - or hdrs.PROXY_AUTHORIZATION not in proxy_headers - ): - proxy_headers = proxy_headers or CIMultiDict() - proxy_headers[hdrs.PROXY_AUTHORIZATION] = env_proxy_auth + if env_proxy_auth is not None: + resolved_proxy_headers = CIMultiDict( + {hdrs.PROXY_AUTHORIZATION: env_proxy_auth} + ) req = self._request_class( method, @@ -684,7 +686,7 @@ async def _request( session=self, ssl=ssl, server_hostname=server_hostname, - proxy_headers=proxy_headers, + proxy_headers=resolved_proxy_headers, traces=traces, trust_env=self.trust_env, ) diff --git a/aiohttp/client_middleware_digest_auth.py b/aiohttp/client_middleware_digest_auth.py index 8151dea5154..6c3e37f7c00 100644 --- a/aiohttp/client_middleware_digest_auth.py +++ b/aiohttp/client_middleware_digest_auth.py @@ -162,6 +162,15 @@ class DigestAuthMiddleware: - Includes replay attack protection with client nonce count tracking - Supports preemptive authentication per RFC 7616 Section 3.6 + Origin scoping: + The credentials are scoped to the origin of the first request the + middleware handles. A request to a different origin is passed through + untouched, so it never receives a digest response computed from those + credentials, unless that origin falls within a protection space the + anchor origin advertised through the RFC 7616 ``domain`` directive. Make + the first request through the middleware against the intended origin, as + the anchor is pinned to it and not reset for the life of the instance. + Standards compliance: - RFC 7616: HTTP Digest Access Authentication (primary reference) - RFC 2617: HTTP Authentication (deprecated by RFC 7616) @@ -198,6 +207,8 @@ def __init__( self._preemptive: bool = preemptive # Set of URLs defining the protection space self._protection_space: list[str] = [] + # Origin the credentials are scoped to; set on the first request. + self._origin: URL | None = None async def _encode(self, method: str, url: URL, body: Payload | Literal[b""]) -> str: """ @@ -447,6 +458,16 @@ async def __call__( self, request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: """Run the digest auth middleware.""" + # Credentials are scoped to the first request's origin. Other origins + # pass through untouched unless a challenge from the anchor origin + # advertised them via RFC 7616 domain; mirrors aiohttp stripping + # Authorization on cross-origin redirects. + origin = request.url.origin() + if self._origin is None: + self._origin = origin + elif origin != self._origin and not self._in_protection_space(request.url): + return await handler(request) + response = None for retry_count in range(2): # Apply authorization header if: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 310f744390d..8a060105146 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -180,6 +180,7 @@ class ConnectionKey(NamedTuple): ssl: SSLContext | bool | Fingerprint proxy: URL | None proxy_headers_hash: int | None # hash(CIMultiDict) + server_hostname: str | None = None class ClientResponse(HeadersMixin): @@ -818,6 +819,7 @@ def connection_key(self) -> ConnectionKey: self._ssl, None, None, + self.server_hostname, ), ) @@ -1055,6 +1057,7 @@ def connection_key(self) -> ConnectionKey: self._ssl, self.proxy, h, + self.server_hostname, ), ) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 6e70b3a28a2..90cf6e11046 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -28,6 +28,7 @@ ClientConnectorSSLError, ClientHttpProxyError, ClientProxyConnectionError, + InvalidUrlClientError, ServerFingerprintMismatch, UnixClientConnectorError, cert_errors, @@ -43,6 +44,7 @@ from .helpers import ( _SENTINEL, ceil_timeout, + is_canonical_ipv4_address, is_ip_address, sentinel, set_exception, @@ -1046,6 +1048,11 @@ async def _resolve_host( ) -> list[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): + # Reject legacy numeric IPv4 forms (e.g. 2130706433, 127.1) that + # socket would map onto an address, slipping past a connector-level + # policy that only sees the raw host. + if ":" not in host and not is_canonical_ipv4_address(host): + raise InvalidUrlClientError(host, "is not a canonical IPv4 address") return [ { "hostname": host, diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index b6b62c901af..8e945f3470d 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -37,6 +37,9 @@ _MIN_SCHEDULED_COOKIE_EXPIRATION = 100 _SIMPLE_COOKIE = SimpleCookie() +# Not persisted; the absolute deadline is saved instead. +_RELATIVE_EXPIRY_ATTRS = frozenset(("max-age", "expires")) + class CookieJar(AbstractCookieJar): """Implements cookie storage adhering to RFC 6265.""" @@ -133,21 +136,28 @@ def save(self, file_path: PathLike) -> None: :class:`str` or :class:`pathlib.Path` instance. """ file_path = pathlib.Path(file_path) - data: dict[str, dict[str, dict[str, str | bool]]] = {} + data: dict[str, dict[str, dict[str, str | bool | float]]] = {} for (domain, path), cookie in self._cookies.items(): key = f"{domain}|{path}" data[key] = {} for name, morsel in cookie.items(): - morsel_data: dict[str, str | bool] = { + morsel_data: dict[str, str | bool | float] = { "key": morsel.key, "value": morsel.value, "coded_value": morsel.coded_value, } - # Save all morsel attributes that have values + # Skip relative expiry; the absolute deadline is saved below. for attr in morsel._reserved: # type: ignore[attr-defined] + if attr in _RELATIVE_EXPIRY_ATTRS: + continue attr_val = morsel[attr] if attr_val: morsel_data[attr] = attr_val + # Persist or it reloads as a domain cookie and leaks to subdomains. + if (domain, name) in self._host_only_cookies: + morsel_data["host_only"] = True + if (exp := self._expirations.get((domain, path, name))) is not None: + morsel_data["expires_timestamp"] = exp data[key][name] = morsel_data # Cookie persistence may include authentication/session tokens. @@ -164,34 +174,33 @@ def save(self, file_path: PathLike) -> None: def load(self, file_path: PathLike) -> None: """Load cookies from a JSON file. + Replaces the current jar contents; loaded cookies pass through the + same acceptance rules as :meth:`update_cookies`. + :param file_path: Path to file from where cookies will be imported, :class:`str` or :class:`pathlib.Path` instance. """ file_path = pathlib.Path(file_path) with file_path.open(mode="r", encoding="utf-8") as f: data = json.load(f) - self._cookies = self._load_json_data(data) + self._load_json_data(data) def _load_json_data( - self, data: dict[str, dict[str, dict[str, str | bool]]] - ) -> defaultdict[tuple[str, str], SimpleCookie]: - """Load cookies from parsed JSON data.""" - cookies: defaultdict[tuple[str, str], SimpleCookie] = defaultdict(SimpleCookie) + self, data: dict[str, dict[str, dict[str, str | bool | float]]] + ) -> None: + """Replace contents, routing cookies through update_cookies().""" + self.clear() for compound_key, cookie_data in data.items(): domain, path = compound_key.split("|", 1) - key = (domain, path) for name, morsel_data in cookie_data.items(): morsel: Morsel[str] = Morsel() - morsel_key = morsel_data["key"] - morsel_value = morsel_data["value"] - morsel_coded_value = morsel_data["coded_value"] # Use __setstate__ to bypass validation, same pattern # used in _build_morsel and _cookie_helpers. morsel.__setstate__( # type: ignore[attr-defined] { - "key": morsel_key, - "value": morsel_value, - "coded_value": morsel_coded_value, + "key": morsel_data["key"], + "value": morsel_data["value"], + "coded_value": morsel_data["coded_value"], } ) # Restore morsel attributes @@ -202,8 +211,17 @@ def _load_json_data( "coded_value", ): morsel[attr] = morsel_data[attr] - cookies[key][name] = morsel - return cookies + # Drop the domain so update_cookies() re-marks it host-only. + if morsel_data.get("host_only"): + morsel["domain"] = "" + response_url = ( + URL.build(scheme="https", host=domain) if domain else URL() + ) + self.update_cookies({name: morsel}, response_url) + # Restore the absolute deadline; update_cookies() schedules none. + if (exp := morsel_data.get("expires_timestamp")) is not None: + self._expire_cookie(float(exp), domain, path, name) + self._do_expiration() def clear(self, predicate: ClearCookiePredicate | None = None) -> None: if predicate is None: diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 55a5e01edcc..88580f00708 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -485,6 +485,28 @@ def is_ip_address(host: str | None) -> bool: return ":" in host or host.replace(".", "").isdigit() +def is_canonical_ipv4_address(host: str) -> bool: + """Check if host is a canonical dotted-quad IPv4 address. + + Rejects the legacy numeric forms that ``socket`` still accepts and + maps onto an address, e.g. ``2130706433``, ``017700000001``, ``127.1``. + """ + parts = host.split(".") + if len(parts) != 4: + return False + for part in parts: + # Each octet must be 1-3 ASCII digits; reject unicode digits + # (which ``str.isdigit`` accepts but ``int`` may not), octal + # leading zeros, and values above 255. + if not (1 <= len(part) <= 3) or not part.isascii() or not part.isdigit(): + return False + if part[0] == "0" and len(part) != 1: + return False + if int(part) > 255: + return False + return True + + _cached_current_datetime: int | None = None _cached_formatted_datetime = "" diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index cfd56c0997e..cd4342677aa 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -937,6 +937,10 @@ def feed_data( if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) if pos >= 0: + # Only chunk-size lines reach here; trailers enforce + # _max_field_size separately in PARSE_TRAILERS below. + if pos > self._max_line_size: + raise LineTooLong(chunk[:100] + b"...", self._max_line_size) i = chunk.find(CHUNK_EXT, 0, pos) if i >= 0: size_b = chunk[:i] # strip chunk-extensions diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 28280ed286c..72d26e607d7 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -543,14 +543,21 @@ def _read_nowait(self, n: int) -> bytes: """Read not more than n bytes, or whole buffer if n == -1""" self._timer.assert_timeout() - chunks = [] + if n == -1: + # Drain only chunks present now; _read_nowait_chunk() can + # re-entrantly resume_reading() and refill the buffer. + count = len(self._buffer) + if count == 1: + return self._read_nowait_chunk(-1) + return b"".join([self._read_nowait_chunk(-1) for _ in range(count)]) + + chunks: list[bytes] = [] while self._buffer: chunk = self._read_nowait_chunk(n) chunks.append(chunk) - if n != -1: - n -= len(chunk) - if n == 0: - break + n -= len(chunk) + if n == 0: + break return b"".join(chunks) if chunks else b"" diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index f231ca5d289..441d7f02900 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -688,8 +688,10 @@ async def write_eof(self, data: bytes = b"") -> None: if body is None or self._must_be_empty_body: await super().write_eof() elif isinstance(self._body, Payload): - await self._body.write(self._payload_writer) - await self._body.close() + try: + await self._body.write(self._payload_writer) + finally: + await self._body.close() await super().write_eof() else: await super().write_eof(cast(bytes, body)) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index b2e7b580c82..9bcf87ad49b 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -2364,6 +2364,16 @@ Utilities The server may still respond with a 401 status and ``stale=true`` if the nonce has expired, in which case the middleware will automatically retry with the new nonce. + **Origin scoping** + + The credentials are scoped to the origin of the first request the middleware + handles. A request to a different origin is passed through untouched, so it + never receives a digest response computed from those credentials, unless that + origin falls within a protection space the anchor origin advertised through + the RFC 7616 ``domain`` directive. Make the first request through the + middleware against the intended origin, as the anchor is pinned to it and not + reset for the life of the instance. + To disable preemptive authentication and require a 401 challenge for every request, set ``preemptive=False``:: @@ -2389,6 +2399,10 @@ Utilities .. versionadded:: 3.12 .. versionchanged:: 3.12.8 Added ``preemptive`` parameter to enable/disable preemptive authentication. + .. versionchanged:: 3.14.1 + Credentials are scoped to the origin of the first request the middleware + handles; other origins are passed through untouched unless covered by an + RFC 7616 ``domain`` directive from the anchor origin. .. class:: CookieJar(*, unsafe=False, quote_cookie=True, treat_as_secure_origin = []) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index d41d1cba2af..6fe7631bc75 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -722,6 +722,37 @@ async def handler(request: web.Request) -> web.Response: assert txt == "Test message" +async def test_server_hostname_override_not_reused( + aiohttp_server: AiohttpServer, +) -> None: + """A pooled TLS connection must not be reused for a different server_hostname.""" + trustme = pytest.importorskip("trustme") + + ca = trustme.CA() + cert = ca.issue_cert("first.example") + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + cert.configure_cert(server_ctx) + client_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) + ca.configure_trust(client_ctx) + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="ok") + + app = web.Application() + app.router.add_route("GET", "/", handler) + server = await aiohttp_server(app, ssl=server_ctx) + url = server.make_url("/") + + connector = aiohttp.TCPConnector(ssl=client_ctx, limit=1, limit_per_host=1) + async with aiohttp.ClientSession(connector=connector) as session: + async with session.get(url, server_hostname="first.example") as resp: + assert resp.status == 200 + await resp.read() + + with pytest.raises(aiohttp.ClientConnectorCertificateError): + await session.get(url, server_hostname="second.example") + + @pytest.mark.skipif( sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" ) diff --git a/tests/test_client_middleware_digest_auth.py b/tests/test_client_middleware_digest_auth.py index 7d0064cbe7f..ca1a9267b67 100644 --- a/tests/test_client_middleware_digest_auth.py +++ b/tests/test_client_middleware_digest_auth.py @@ -1198,6 +1198,176 @@ async def handler(request: Request) -> Response: ) # Second request - preemptive auth (entire origin) +async def test_does_not_answer_cross_origin_redirect_challenge( + aiohttp_server: AiohttpServer, +) -> None: + """A cross-origin redirect target must not receive a digest response. + + aiohttp strips the Authorization header on cross-origin redirects; the + digest middleware must not re-add one for the redirect target, otherwise + the configured credentials leak to an origin the caller never targeted. + """ + target_auth_headers: list[str | None] = [] + + async def target_handler(request: Request) -> Response: + auth_header = request.headers.get(hdrs.AUTHORIZATION) + target_auth_headers.append(auth_header) + assert auth_header is None + return Response( + status=401, + headers={ + hdrs.WWW_AUTHENTICATE: 'Digest realm="evil", nonce="cross-origin"' + }, + ) + + target_app = Application() + target_app.router.add_get("/", target_handler) + target_server = await aiohttp_server(target_app) + + async def source_handler(request: Request) -> Response: + return Response( + status=302, headers={hdrs.LOCATION: str(target_server.make_url("/"))} + ) + + source_app = Application() + source_app.router.add_get("/", source_handler) + source_server = await aiohttp_server(source_app) + + digest_auth = DigestAuthMiddleware("victim", "secret") + async with ( + ClientSession(middlewares=(digest_auth,)) as session, + session.get(source_server.make_url("/")) as response, + ): + await response.text() + + assert target_auth_headers == [None] + + +async def test_answers_same_origin_redirect_challenge( + aiohttp_server: AiohttpServer, +) -> None: + """A same-origin redirect that issues a challenge must still authenticate.""" + auth_headers: list[str | None] = [] + + async def handler(request: Request) -> Response: + if request.path == "/start": + return Response(status=302, headers={hdrs.LOCATION: "/protected"}) + auth_header = request.headers.get(hdrs.AUTHORIZATION) + auth_headers.append(auth_header) + if auth_header is None: + return Response( + status=401, + headers={hdrs.WWW_AUTHENTICATE: 'Digest realm="good", nonce="abc"'}, + ) + return Response(text="OK") + + app = Application() + app.router.add_get("/start", handler) + app.router.add_get("/protected", handler) + server = await aiohttp_server(app) + + digest_auth = DigestAuthMiddleware("user", "pass") + async with ( + ClientSession(middlewares=(digest_auth,)) as session, + session.get(server.make_url("/start")) as response, + ): + assert response.status == 200 + assert await response.text() == "OK" + + assert auth_headers[0] is None + assert auth_headers[1] is not None + assert auth_headers[1].startswith("Digest") + + +async def test_answers_cross_origin_within_domain_protection_space( + aiohttp_server: AiohttpServer, +) -> None: + """A different origin advertised via the ``domain`` directive is honored. + + RFC 7616 allows a challenge to define a protection space spanning other + servers through the ``domain`` directive. The anchor origin vouches for + those URIs, so preemptive auth to them is expected. + """ + other_auth_headers: list[str | None] = [] + + async def other_handler(request: Request) -> Response: + other_auth_headers.append(request.headers.get(hdrs.AUTHORIZATION)) + return Response(text="other") + + other_app = Application() + other_app.router.add_get("/", other_handler) + other_server = await aiohttp_server(other_app) + other_origin = str(other_server.make_url("/").origin()) + + async def anchor_handler(request: Request) -> Response: + if request.headers.get(hdrs.AUTHORIZATION) is None: + challenge = f'Digest realm="anchor", nonce="n1", domain="{other_origin}/"' + return Response(status=401, headers={hdrs.WWW_AUTHENTICATE: challenge}) + return Response(text="anchor") + + anchor_app = Application() + anchor_app.router.add_get("/", anchor_handler) + anchor_server = await aiohttp_server(anchor_app) + + digest_auth = DigestAuthMiddleware("user", "pass") + async with ClientSession(middlewares=(digest_auth,)) as session: + async with session.get(anchor_server.make_url("/")) as response: + assert response.status == 200 + async with session.get(other_server.make_url("/")) as response: + assert response.status == 200 + + assert other_auth_headers[0] is not None + assert other_auth_headers[0].startswith("Digest") + + +async def test_does_not_answer_cross_origin_challenge_without_redirect( + aiohttp_server: AiohttpServer, +) -> None: + """Origin scoping applies to any cross-origin request, not just redirects. + + After authenticating against the anchor origin, a direct request to a + different origin that issues its own challenge must not be answered with a + digest response computed from the configured credentials. + """ + other_auth_headers: list[str | None] = [] + + async def other_handler(request: Request) -> Response: + auth_header = request.headers.get(hdrs.AUTHORIZATION) + other_auth_headers.append(auth_header) + assert auth_header is None + return Response( + status=401, + headers={hdrs.WWW_AUTHENTICATE: 'Digest realm="evil", nonce="x"'}, + ) + + other_app = Application() + other_app.router.add_get("/", other_handler) + other_server = await aiohttp_server(other_app) + + async def anchor_handler(request: Request) -> Response: + if request.headers.get(hdrs.AUTHORIZATION) is None: + return Response( + status=401, + headers={hdrs.WWW_AUTHENTICATE: 'Digest realm="anchor", nonce="n1"'}, + ) + return Response(text="anchor") + + anchor_app = Application() + anchor_app.router.add_get("/", anchor_handler) + anchor_server = await aiohttp_server(anchor_app) + + digest_auth = DigestAuthMiddleware("user", "pass") + async with ClientSession(middlewares=(digest_auth,)) as session: + async with session.get(anchor_server.make_url("/")) as response: + assert response.status == 200 + async with session.get(other_server.make_url("/")) as response: + assert response.status == 401 + + # The other origin only ever saw the unauthenticated request; the + # middleware never answered its challenge. + assert other_auth_headers == [None] + + @pytest.mark.parametrize( ("status", "headers", "expected"), [ diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 29488d85493..ccf3b5d66a7 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -1755,6 +1755,22 @@ async def test_connection_key_without_proxy( await req._close() +async def test_connection_key_includes_server_hostname( + make_client_request: _RequestMaker, +) -> None: + """A server_hostname override must be part of the connection reuse key.""" + url = URL("https://127.0.0.1:8443/") + none_req = make_client_request("GET", url) + first = make_client_request("GET", url, server_hostname="first.example") + first_again = make_client_request("GET", url, server_hostname="first.example") + second = make_client_request("GET", url, server_hostname="second.example") + + assert first.connection_key.server_hostname == "first.example" + assert first.connection_key != none_req.connection_key + assert first.connection_key != second.connection_key + assert first.connection_key == first_again.connection_key + + def test_request_info_back_compat() -> None: """Test RequestInfo can be created without real_url.""" url = URL("http://example.com") diff --git a/tests/test_connector.py b/tests/test_connector.py index 0b1cbcff03e..259dfb495e5 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -32,6 +32,7 @@ web, ) from aiohttp.abc import AbstractResolver, ResolveResult +from aiohttp.client_exceptions import InvalidUrlClientError from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequestArgs, ConnectionKey from aiohttp.connector import ( @@ -1303,6 +1304,35 @@ async def test_tcp_connector_resolve_host() -> None: await conn.close() +async def test_tcp_connector_rejects_non_canonical_ipv4_alias() -> None: + """Legacy numeric IPv4 aliases must not bypass the configured resolver.""" + calls: list[str] = [] + + class _RecordingResolver(AbstractResolver): + async def resolve( + self, + host: str, + port: int = 0, + family: socket.AddressFamily = socket.AF_INET, + ) -> list[ResolveResult]: + assert False + + async def close(self) -> None: + """Close the resolver.""" + + conn = aiohttp.TCPConnector(resolver=_RecordingResolver()) + for alias in ("2130706433", "017700000001", "127.1"): + with pytest.raises(InvalidUrlClientError, match="canonical IPv4"): + await conn._resolve_host(alias, 8080) + + # Resolver is never consulted, and a canonical IP still short-circuits it. + assert calls == [] + res = await conn._resolve_host("127.0.0.1", 8080) + assert res[0]["host"] == "127.0.0.1" + assert calls == [] + await conn.close() + + @pytest.fixture def dns_response() -> Callable[[], Awaitable[list[str]]]: async def coro() -> list[str]: diff --git a/tests/test_cookiejar.py b/tests/test_cookiejar.py index 0693f41f617..0c24df71c0b 100644 --- a/tests/test_cookiejar.py +++ b/tests/test_cookiejar.py @@ -1,6 +1,7 @@ import datetime import heapq import itertools +import json import logging import os import stat @@ -1623,6 +1624,140 @@ def test_save_load_json_partitioned_cookies(tmp_path: Path) -> None: assert s["path"] == lo["path"] +def test_save_load_json_preserves_host_only_scope(tmp_path: Path) -> None: + """Verify save/load keeps host-only cookies off subdomains.""" + file_path = tmp_path / "host_only.json" + issuer = URL("https://auth.example.com/login") + subdomain = URL("https://sub.auth.example.com/") + + jar_save = CookieJar() + jar_save.update_cookies({"sid": "hostonly"}, response_url=issuer) + assert "sid" not in jar_save.filter_cookies(subdomain) + jar_save.save(file_path=file_path) + + jar_load = CookieJar() + jar_load.load(file_path=file_path) + + assert jar_load.host_only_cookies == frozenset({("auth.example.com", "sid")}) + assert "sid" not in jar_load.filter_cookies(subdomain) + assert "sid" in jar_load.filter_cookies(issuer) + + +def test_save_load_json_domain_cookie_still_matches_subdomain( + tmp_path: Path, +) -> None: + """Verify save/load keeps an explicit Domain cookie valid for subdomains.""" + file_path = tmp_path / "domain.json" + subdomain = URL("https://sub.example.com/") + + jar_save = CookieJar() + jar_save.update_cookies_from_headers( + ["sid=domaincookie; Domain=example.com"], URL("https://example.com/") + ) + jar_save.save(file_path=file_path) + + jar_load = CookieJar() + jar_load.load(file_path=file_path) + + assert jar_load.host_only_cookies == frozenset() + assert "sid" in jar_load.filter_cookies(subdomain) + + +def test_save_load_json_preserves_max_age_deadline(tmp_path: Path) -> None: + """Verify save/load restores the absolute deadline without resetting it.""" + file_path = tmp_path / "max_age.json" + url = URL("https://example.com/") + + jar_save = CookieJar() + jar_save.update_cookies_from_headers( + ["sid=x; Max-Age=3600; Domain=example.com"], url + ) + expirations = dict(jar_save._expirations) + jar_save.save(file_path=file_path) + + jar_load = CookieJar() + jar_load.load(file_path=file_path) + + # The deadline is restored as the original absolute time, not now + Max-Age. + assert dict(jar_load._expirations) == expirations + assert "sid" in jar_load.filter_cookies(url) + + +def test_save_load_json_drops_expired_cookie(tmp_path: Path) -> None: + """Verify a cookie whose persisted deadline is in the past is dropped on load.""" + file_path = tmp_path / "expired.json" + url = URL("https://example.com/") + + # Save a future-expiring cookie, then rewrite its persisted deadline to the + # past so the cookie survives save() and the drop happens on the load path. + jar_save = CookieJar() + jar_save.update_cookies_from_headers( + ["sid=x; Expires=Tue, 1 Jan 2999 12:00:00 GMT; Domain=example.com"], url + ) + jar_save.save(file_path=file_path) + data = json.loads(file_path.read_text()) + _, cookies = next(iter(data.items())) + cookies["sid"]["expires_timestamp"] = 0.0 + file_path.write_text(json.dumps(data)) + + jar_load = CookieJar() + jar_load.load(file_path=file_path) + + assert len(jar_load) == 0 + assert "sid" not in jar_load.filter_cookies(url) + + +def test_save_load_json_preserves_expires_deadline(tmp_path: Path) -> None: + """Verify a future Expires deadline survives a save/load roundtrip.""" + file_path = tmp_path / "expires.json" + url = URL("https://example.com/") + + jar_save = CookieJar() + jar_save.update_cookies_from_headers( + ["sid=x; Expires=Tue, 1 Jan 2999 12:00:00 GMT; Domain=example.com"], url + ) + expirations = dict(jar_save._expirations) + jar_save.save(file_path=file_path) + + jar_load = CookieJar() + jar_load.load(file_path=file_path) + + assert dict(jar_load._expirations) == expirations + assert "sid" in jar_load.filter_cookies(url) + + +def test_load_json_old_format_without_new_keys(tmp_path: Path) -> None: + """Verify a file written by an older version (no host_only/expires_timestamp) loads.""" + file_path = tmp_path / "old.json" + # Old schema: no host_only, no expires_timestamp; relative max-age morsel attr. + file_path.write_text( + json.dumps( + { + "example.com|/": { + "sid": { + "key": "sid", + "value": "x", + "coded_value": "x", + "domain": "example.com", + "max-age": "3600", + } + } + } + ) + ) + url = URL("https://example.com/") + + jar_load = CookieJar() + # No exception when the new keys are absent. + jar_load.load(file_path=file_path) + + # A host-only cookie saved without Domain by an older version had no domain + # field, so it now loads as a domain cookie (the documented migration loss). + assert "sid" in jar_load.filter_cookies(url) + # max-age is rescheduled from load time rather than an absolute deadline. + assert any(key[2] == "sid" for key in jar_load._expirations) + + def test_json_format_is_safe(tmp_path: Path) -> None: """Verify the JSON file format cannot execute code on load.""" import json diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 144c3677c9e..a499746607a 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,8 @@ import asyncio import datetime import gc +import ipaddress +import itertools import sys import weakref from collections.abc import Iterator @@ -230,6 +232,93 @@ def test_is_ip_address_invalid_type() -> None: helpers.is_ip_address(object()) # type: ignore[arg-type] +# ------------------------------- is_canonical_ipv4_address() --------------- + + +@pytest.mark.parametrize( + "host", + [ + "0.0.0.0", + "127.0.0.1", + "8.8.8.8", + "192.168.0.1", + "255.255.255.255", + ], +) +def test_is_canonical_ipv4_address_accepts_dotted_quad(host: str) -> None: + assert helpers.is_canonical_ipv4_address(host) + + +@pytest.mark.parametrize( + "host", + [ + "2130706433", # decimal integer form of 127.0.0.1 + "017700000001", # octal form of 127.0.0.1 + "127.1", # short-hand form of 127.0.0.1 + "127.0.1", # 3-part short-hand + "0177.0.0.1", # octal leading-zero octet + "01.2.3.4", # octal leading-zero octet + "256.0.0.1", # octet out of range + "999.0.0.1", # octet out of range + "1.2.3.4.5", # too many octets + "127.0.0.", # trailing dot / empty octet + "12³.0.0.1", # superscript digit (str.isdigit but not int) + "127.0.0.1", # full-width digits + "0xa.0.0.0", # hex octet + " 127.0.0.1", # leading whitespace + "127.0.0.1 ", # trailing whitespace + "example.com", # domain name + "", # empty + ], +) +def test_is_canonical_ipv4_address_rejects_non_canonical(host: str) -> None: + assert not helpers.is_canonical_ipv4_address(host) + + +def _ipaddress_accepts_ipv4(host: str) -> bool: + """Oracle: does the stdlib accept ``host`` as a canonical IPv4 address?""" + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + return False + return True + + +def test_is_canonical_ipv4_address_matches_stdlib() -> None: + """Prove equivalence with ``ipaddress.IPv4Address`` over a broad corpus. + + The helper is a fast hand-rolled substitute for the stdlib parser; this + exhaustively cross-checks the two agree on every combination of a set of + octet-like tokens covering the known edge cases (leading zeros, out of + range, empty, unicode digits, wrong octet count). + """ + tokens = [ + "0", + "1", + "9", + "10", + "99", + "255", + "256", + "999", + "00", + "01", + "0177", + "1234", + "", + "a", + "0x1", + "1", # full-width 1 + "1²", # trailing superscript + ] + for count in range(1, 5): + for parts in itertools.product(tokens, repeat=count): + host = ".".join(parts) + assert helpers.is_canonical_ipv4_address(host) == _ipaddress_accepts_ipv4( + host + ), host + + # ----------------------------------- TimeoutHandle ------------------- diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 6cb5dce5394..f95e4cfacf1 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1657,6 +1657,17 @@ def test_http_request_max_status_line_under_limit(parser: HttpRequestParser) -> assert msg.url == URL("/path" + path.decode()) +def test_http_request_max_status_line_fragmented( + parser: HttpRequestParser, +) -> None: + # Split an overlong request target across reads so that each callback + # fragment is under the limit but the accumulated target is not. + match = "400, message:\n Got more than 8190 bytes when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): + parser.feed_data(b"GET /" + b"a" * 8000) + parser.feed_data(b"a" * 8000 + b" HTTP/1.1\r\nHost: a\r\n\r\n") + + def test_http_response_parser_utf8(response: HttpResponseParser) -> None: text = "HTTP/1.1 200 Ok\r\nx-test:тест\r\n\r\n".encode() @@ -1738,6 +1749,17 @@ def test_http_response_parser_status_line_under_limit( assert msg.reason == reason.decode() +def test_http_response_parser_status_line_too_long_fragmented( + response: HttpResponseParser, +) -> None: + # Split an overlong reason phrase across reads so that each callback + # fragment is under the limit but the accumulated reason is not. + match = "400, message:\n Got more than 8190 bytes when reading" + with pytest.raises(http_exceptions.LineTooLong, match=match): + response.feed_data(b"HTTP/1.1 200 " + b"a" * 8000) + response.feed_data(b"a" * 8000 + b"\r\n\r\n") + + def test_http_response_parser_bad_version(response: HttpResponseParser) -> None: with pytest.raises(http_exceptions.BadHttpMessage): response.feed_data(b"HT/11 200 Ok\r\n\r\n") @@ -2377,6 +2399,59 @@ async def test_parse_chunked_payload_size_error( p.feed_data(b"blah\r\n") assert isinstance(out.exception(), http_exceptions.TransferEncodingError) + async def test_chunked_chunk_size_line_too_long( + self, protocol: BaseProtocol + ) -> None: + """A complete oversized chunk-size line is rejected with LineTooLong.""" + out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + p = HttpPayloadParser( + out, chunked=True, headers_parser=HeadersParser(), max_line_size=32 + ) + size_line = b"1;" + b"a" * 4096 + b"\r\n" + with pytest.raises(http_exceptions.LineTooLong): + p.feed_data(size_line) + + async def test_chunked_chunk_size_line_within_limit( + self, protocol: BaseProtocol + ) -> None: + """A small chunk-size line still parses when max_line_size is low.""" + out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + p = HttpPayloadParser( + out, chunked=True, headers_parser=HeadersParser(), max_line_size=32 + ) + p.feed_data(b"1\r\nx\r\n0\r\n\r\n") + assert out.is_eof() + assert b"x" == b"".join(out._buffer) + + async def test_chunked_chunk_size_line_at_limit( + self, protocol: BaseProtocol + ) -> None: + """A chunk-size line of exactly max_line_size bytes is accepted (>, not >=).""" + out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + p = HttpPayloadParser( + out, chunked=True, headers_parser=HeadersParser(), max_line_size=32 + ) + # "1;" + 30 * "a" is exactly 32 bytes before the CRLF. + size_line = b"1;" + b"a" * 30 + assert len(size_line) == 32 + p.feed_data(size_line + b"\r\nx\r\n0\r\n\r\n") + assert out.is_eof() + assert b"x" == b"".join(out._buffer) + + async def test_chunked_chunk_size_line_one_over_limit( + self, protocol: BaseProtocol + ) -> None: + """A chunk-size line one byte over max_line_size is rejected.""" + out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + p = HttpPayloadParser( + out, chunked=True, headers_parser=HeadersParser(), max_line_size=32 + ) + # "1;" + 31 * "a" is 33 bytes before the CRLF. + size_line = b"1;" + b"a" * 31 + assert len(size_line) == 33 + with pytest.raises(http_exceptions.LineTooLong): + p.feed_data(size_line + b"\r\nx\r\n0\r\n\r\n") + async def test_parse_chunked_payload_size_data_mismatch( self, protocol: BaseProtocol ) -> None: diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 761be8ed1de..eefe51db251 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest +from multidict import CIMultiDict from pytest_aiohttp import AiohttpRawServer, AiohttpServer from pytest_mock import MockerFixture from yarl import URL @@ -908,6 +909,58 @@ async def test_proxy_from_env_http_without_auth_from_wrong_netrc( assert "Proxy-Authorization" not in proxy.request.headers +async def test_proxy_from_env_auth_scoped_to_redirect_selected_proxy( + aiohttp_raw_server: AiohttpRawServer, + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + # Redirect from an authenticated http_proxy to an HTTPS target served by a + # separate https_proxy must not leak the first proxy's Proxy-Authorization + # onto the second proxy's CONNECT request. + auth_header = aiohttp.encode_basic_auth("user", "pass") + http_proxy_requests: list[CIMultiDict[str]] = [] + https_proxy_requests: list[CIMultiDict[str]] = [] + + async def http_proxy_handler(request: web.Request) -> web.Response: + http_proxy_requests.append(CIMultiDict(request.headers)) + return web.Response( + status=302, headers={"Location": "https://attacker.example/secret"} + ) + + async def https_proxy_handler(request: web.Request) -> web.Response: + https_proxy_requests.append(CIMultiDict(request.headers)) + return web.Response(status=502) + + http_proxy = await aiohttp_raw_server(http_proxy_handler) + https_proxy = await aiohttp_raw_server(https_proxy_handler) + + for name in ( + "http_proxy", + "https_proxy", + "all_proxy", + "no_proxy", + "HTTP_PROXY", + "HTTPS_PROXY", + "ALL_PROXY", + "NO_PROXY", + ): + monkeypatch.delenv(name, raising=False) + netrc_file = tmp_path / "empty_netrc" + netrc_file.write_text("") + http_proxy_url = http_proxy.make_url("/").with_user("user").with_password("pass") + monkeypatch.setenv("http_proxy", str(http_proxy_url)) + monkeypatch.setenv("https_proxy", str(https_proxy.make_url("/"))) + monkeypatch.setenv("NETRC", str(netrc_file)) + + with pytest.raises(aiohttp.ClientHttpProxyError): + await get_request(url="http://victim.example/redirect", trust_env=True) + + assert len(http_proxy_requests) == 1 + assert http_proxy_requests[0]["Proxy-Authorization"] == auth_header + assert len(https_proxy_requests) == 1 + assert "Proxy-Authorization" not in https_proxy_requests[0] + + @pytest.mark.xfail async def test_proxy_from_env_https( proxy_test_server: Callable[[], Awaitable[mock.Mock]], mocker: MockerFixture diff --git a/tests/test_streams.py b/tests/test_streams.py index e11218680d4..fce10a4ec14 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1709,3 +1709,31 @@ def resume_reading() -> None: protocol.resume_reading.assert_called() assert protocol._reading_paused is False + + +async def test_readany_does_not_drain_reentrant_refill( + protocol: mock.Mock, +) -> None: + """A single readany() must not reassemble data fed re-entrantly. + + Draining below the low water mark resumes reading, which can synchronously + refill the buffer (e.g. decompressing another chunk). Joining that refill in + one call would reassemble an unbounded body. + """ + loop = asyncio.get_running_loop() + stream = streams.StreamReader(protocol, limit=4, loop=loop) + + refills = [b"second", b"third"] + + def resume_reading() -> None: + if refills: + stream.feed_data(refills.pop(0)) + + protocol.resume_reading.side_effect = resume_reading + + stream.feed_data(b"first") + + # Popping "first" refills "second", but this readany() returns only "first". + assert await stream.readany() == b"first" + assert await stream.readany() == b"second" + assert await stream.readany() == b"third" diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 2dd5bc3a2ac..c760bff2f19 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -4,7 +4,9 @@ import pathlib import socket import sys +import zlib from collections.abc import AsyncIterator, Awaitable, Callable, Generator +from contextlib import suppress from typing import NoReturn from unittest import mock @@ -27,7 +29,8 @@ from aiohttp.abc import AbstractResolver, ResolveResult from aiohttp.compression_utils import ZLibBackend, ZLibCompressObjProtocol from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING -from aiohttp.helpers import HeadersDictProxy +from aiohttp.helpers import DEFAULT_CHUNK_SIZE, HeadersDictProxy +from aiohttp.streams import StreamReader from aiohttp.typedefs import Handler, Middleware from aiohttp.web_protocol import RequestHandler @@ -1714,6 +1717,65 @@ async def handler(request: web.Request) -> web.StreamResponse: resp.release() +@pytest.mark.parametrize("decompressed_size", [4 * 1024 * 1024, 32 * 1024 * 1024]) +async def test_unread_compressed_body_drain_is_bounded( + aiohttp_server: AiohttpServer, + monkeypatch: pytest.MonkeyPatch, + decompressed_size: int, +) -> None: + """Draining an unread compressed body stays bounded by the read buffer. + + A handler that rejects before reading still drains the payload during + lingering close; a small compressed body must not force a large transient + allocation (a deflate-bomb style DoS). + """ + drain_reads: list[int] = [] + drained = asyncio.Event() + readany = StreamReader.readany + + async def record_readany(self: StreamReader) -> bytes: + data = await readany(self) + assert data + drain_reads.append(len(data)) + drained.set() + return data + + monkeypatch.setattr(StreamReader, "readany", record_readany) + + async def handler(request: web.Request) -> web.Response: + return web.Response(status=401) + + app = web.Application(client_max_size=1024) + app.router.add_post("/", handler) + server = await aiohttp_server(app) + + body = zlib.compress(b"a" * decompressed_size) + assert len(body) < decompressed_size + head = ( + b"POST / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Encoding: deflate\r\n" + b"Content-Length: %d\r\n" + b"Connection: keep-alive\r\n\r\n" + ) % len(body) + + reader, writer = await asyncio.open_connection(server.host, server.port) + try: + writer.write(head + body) + await writer.drain() + status_line = await asyncio.wait_for(reader.readline(), 5) + assert status_line.startswith(b"HTTP/1.1 401 ") + await asyncio.wait_for(drained.wait(), 5) + finally: + writer.close() + with suppress(ConnectionResetError, BrokenPipeError): + await writer.wait_closed() + + # Bounded by the buffer, not the decompressed size. + assert max(drain_reads) <= 3 * DEFAULT_CHUNK_SIZE + assert max(drain_reads) < decompressed_size + + async def test_app_max_client_size(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await request.post() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 3dee8d3a1b1..cc564cf832f 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -1,3 +1,4 @@ +import asyncio import collections.abc import datetime import gzip @@ -19,7 +20,7 @@ from aiohttp.helpers import ETag, HeadersDictProxy from aiohttp.http_writer import StreamWriter, _serialize_headers from aiohttp.multipart import BodyPartReader, MultipartWriter -from aiohttp.payload import BytesPayload, StringPayload +from aiohttp.payload import BytesPayload, Payload, StringPayload from aiohttp.test_utils import make_mocked_request from aiohttp.typedefs import LooseHeaders @@ -1274,6 +1275,76 @@ async def test_consecutive_write_eof() -> None: writer.write_eof.assert_called_once_with(data) +class _ClosingPayload(Payload): + """Payload test double that records whether close() ran.""" + + def __init__(self) -> None: + super().__init__(None) + self.close_called = False + self.started = asyncio.Event() + self.release = asyncio.Event() + self.fail = False + + async def write(self, writer: AbstractStreamWriter) -> None: + self.started.set() + if self.fail: + raise ConnectionResetError("client gone") + await self.release.wait() + + async def close(self) -> None: + self.close_called = True + await super().close() + + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + assert False + + +async def test_write_eof_closes_payload_on_success() -> None: + writer = mock.create_autospec(AbstractStreamWriter, spec_set=True, instance=True) + req = make_request("GET", "/", writer=writer) + payload = _ClosingPayload() + payload.release.set() + resp = web.Response(body=payload) + + await resp.prepare(req) + await resp.write_eof() + + assert payload.close_called + assert writer.write_eof.called + + +async def test_write_eof_closes_payload_on_write_error() -> None: + writer = mock.create_autospec(AbstractStreamWriter, spec_set=True, instance=True) + req = make_request("GET", "/", writer=writer) + payload = _ClosingPayload() + payload.fail = True + resp = web.Response(body=payload) + + await resp.prepare(req) + with pytest.raises(ConnectionResetError): + await resp.write_eof() + + assert payload.close_called + assert not writer.write_eof.called + + +async def test_write_eof_closes_payload_on_cancel() -> None: + writer = mock.create_autospec(AbstractStreamWriter, spec_set=True, instance=True) + req = make_request("GET", "/", writer=writer) + payload = _ClosingPayload() + resp = web.Response(body=payload) + + await resp.prepare(req) + task = asyncio.ensure_future(resp.write_eof()) + await payload.started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert payload.close_called + assert not writer.write_eof.called + + def test_set_text_with_content_type() -> None: resp = web.Response() resp.content_type = "text/html"