diff --git a/s3proxy/handlers/objects/get.py b/s3proxy/handlers/objects/get.py index b3d038c..901f095 100644 --- a/s3proxy/handlers/objects/get.py +++ b/s3proxy/handlers/objects/get.py @@ -470,13 +470,12 @@ async def _fetch_and_decrypt_frame( frame_ciphertext_size: int, dek: bytes, ) -> bytes: + # No per-frame memory reservation here: concurrent streaming GETs are bounded + # at admission (the request-level reservation is held for the whole stream + # lifetime), so the working set is O(concurrent streams), not O(frames). A + # nested per-frame acquire would deadlock against that held reservation. expected_size = frame_ciphertext_size - additional = max(0, expected_size * 2 - MAX_BUFFER_SIZE) - extra_reserved = 0 try: - if additional > 0: - extra_reserved = await concurrency.try_acquire_memory(additional) - resp = await client.get_object(bucket, key, f"bytes={ct_start}-{ct_end}") async with resp["Body"] as body: ciphertext = await body.read() @@ -515,9 +514,6 @@ async def _fetch_and_decrypt_frame( f"range {ct_start}-{ct_end} invalid" ) from e raise - finally: - if extra_reserved > 0: - await concurrency.release_memory(extra_reserved) async def _fetch_and_decrypt_part( self, @@ -544,21 +540,13 @@ async def _fetch_and_decrypt_part( self._validate_ciphertext_range(bucket, key, part_num, 0, ct_end, actual_size) - part_size = part_meta.ciphertext_size - additional = max(0, part_size * 2 - MAX_BUFFER_SIZE) - extra_reserved = 0 - try: - if additional > 0: - extra_reserved = await concurrency.try_acquire_memory(additional) - - resp = await client.get_object(bucket, key, f"bytes={ct_start}-{ct_end}") - async with resp["Body"] as body: - ciphertext = await body.read() - decrypted = crypto.decrypt(ciphertext, dek) - return decrypted[off_start : off_end + 1] - finally: - if extra_reserved > 0: - await concurrency.release_memory(extra_reserved) + # See _fetch_and_decrypt_frame: stream concurrency is bounded at admission, + # so no per-frame reservation is taken here. + resp = await client.get_object(bucket, key, f"bytes={ct_start}-{ct_end}") + async with resp["Body"] as body: + ciphertext = await body.read() + decrypted = crypto.decrypt(ciphertext, dek) + return decrypted[off_start : off_end + 1] def _build_response_headers(self, resp: dict, last_modified: str | None) -> dict[str, str]: return self._build_headers( diff --git a/s3proxy/request_handler.py b/s3proxy/request_handler.py index 5ac7388..9691a7f 100644 --- a/s3proxy/request_handler.py +++ b/s3proxy/request_handler.py @@ -9,7 +9,7 @@ import structlog from botocore.exceptions import ClientError from fastapi import HTTPException, Request -from fastapi.responses import PlainTextResponse +from fastapi.responses import PlainTextResponse, StreamingResponse from structlog.stdlib import BoundLogger from . import concurrency @@ -47,6 +47,17 @@ def _is_dashboard_path(request: Request, path: str) -> bool: return path == prefix or path.startswith(prefix + "/") +async def _release_after_stream(iterator, reserved: int): + """Wrap a streaming body so its memory reservation is released only after the + stream is fully sent (or the client disconnects), not when the handler returns. + """ + try: + async for chunk in iterator: + yield chunk + finally: + await concurrency.release_memory(reserved) + + def _needs_body_for_signature(headers: dict[str, str]) -> bool: """Body is needed only when x-amz-content-sha256 is absent. @@ -127,6 +138,16 @@ async def handle_proxy_request( response = await _handle_proxy_request_impl(request, handler, verifier) if response is not None: status_code = response.status_code + # A StreamingResponse sends its body AFTER this handler returns. Releasing + # the reservation in the finally below frees it before a byte is sent, + # leaving the stream ungoverned -- and each concurrent stream holds ~one + # decrypted frame in the transport send buffer, so a flood of concurrent + # GETs accumulates frames and OOMs the pod while the limiter reads ~budget. + # Hold the reservation for the whole stream lifetime so the limiter bounds + # how many streaming GETs run at once (admission control). + if reserved_memory > 0 and isinstance(response, StreamingResponse): + response.body_iterator = _release_after_stream(response.body_iterator, reserved_memory) + reserved_memory = 0 return response except HTTPException as e: status_code = e.status_code diff --git a/tests/unit/test_streaming_get_reservation.py b/tests/unit/test_streaming_get_reservation.py new file mode 100644 index 0000000..4d7d1b4 --- /dev/null +++ b/tests/unit/test_streaming_get_reservation.py @@ -0,0 +1,62 @@ +"""A streaming GET must hold its memory reservation for the whole stream. + +A StreamingResponse sends its body AFTER the request handler returns. The +handler used to release the GET's memory reservation in its finally -- i.e. +before a single byte was streamed -- so concurrent streaming GETs ran +ungoverned. Each one holds ~one decrypted frame in the transport send buffer, +so a flood accumulated frames and OOMed the pod while the limiter read ~budget +(reproduced locally: a 90-concurrent multipart GET flood at a 512Mi cap with a +64MB budget OOMKilled the pod, exit 137, 0/180 succeeded; with the reservation +held for the stream lifetime it peaks ~325MiB and completes 180/180). + +_release_after_stream wraps the body iterator so the reservation is released +only when the stream is exhausted or the consumer stops early. These tests pin +that timing: the reservation stays held while streaming and is released exactly +once at teardown. +""" + +import pytest + +from s3proxy import concurrency +from s3proxy.request_handler import _release_after_stream + +MB = 1024 * 1024 + + +@pytest.mark.asyncio +async def test_reservation_held_until_stream_exhausted(): + limiter = concurrency._default + limiter.set_memory_limit(64) + limiter.active_bytes = 0 + reserved = await limiter.try_acquire(8 * MB) + assert limiter.active_bytes == reserved > 0 + + async def source(): + for i in range(3): + # Reservation must still be held while the body is being sent. + assert limiter.active_bytes == reserved + yield f"chunk{i}".encode() + + chunks = [c async for c in _release_after_stream(source(), reserved)] + assert chunks == [b"chunk0", b"chunk1", b"chunk2"] + # Released exactly once after the stream finished. + assert limiter.active_bytes == 0 + + +@pytest.mark.asyncio +async def test_reservation_released_on_early_consumer_exit(): + limiter = concurrency._default + limiter.set_memory_limit(64) + limiter.active_bytes = 0 + reserved = await limiter.try_acquire(8 * MB) + + async def source(): + for i in range(100): + yield bytes(i) + + wrapped = _release_after_stream(source(), reserved) + async for _ in wrapped: + break # client disconnects after one chunk + await wrapped.aclose() + # Reservation released even though the stream was abandoned early. + assert limiter.active_bytes == 0