Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit fb9c976

Browse files
committed
yield requests from generator
1 parent ad68e91 commit fb9c976

File tree

9 files changed

+106
-54
lines changed

9 files changed

+106
-54
lines changed

google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,12 @@ async def append(
359359
attempt_count = 0
360360

361361
def send_and_recv_generator(
362-
requests: List[BidiWriteObjectRequest],
362+
requests_generator,
363363
state: dict[str, _WriteState],
364364
metadata: Optional[List[Tuple[str, str]]] = None,
365365
):
366366
async def generator():
367367
nonlocal attempt_count
368-
nonlocal requests
369368
attempt_count += 1
370369
resp = None
371370
async with self._lock:
@@ -402,16 +401,33 @@ async def generator():
402401
write_state.bytes_sent = write_state.persisted_size
403402
write_state.bytes_since_last_flush = 0
404403

405-
requests = strategy.generate_requests(state)
406-
407-
num_requests = len(requests)
408-
for i, chunk_req in enumerate(requests):
409-
if i == num_requests - 1:
410-
chunk_req.state_lookup = True
411-
chunk_req.flush = True
404+
# Process requests from the generator
405+
# Strategy handles state_lookup and flush on the last request,
406+
# so we just stream requests directly
407+
for chunk_req in requests_generator:
408+
# Check if this is an open/state-lookup request (no checksummed_data)
409+
if chunk_req.state_lookup and not chunk_req.checksummed_data:
410+
# This is an open request - send it and get response
411+
await self.write_obj_stream.send(chunk_req)
412+
resp = await self.write_obj_stream.recv()
413+
414+
# Update state from open response
415+
if resp:
416+
if resp.persisted_size is not None:
417+
self.persisted_size = resp.persisted_size
418+
write_state.persisted_size = resp.persisted_size
419+
self.offset = self.persisted_size
420+
if resp.write_handle:
421+
self.write_handle = resp.write_handle
422+
write_state.write_handle = resp.write_handle
423+
continue
424+
425+
# This is a data request - send it
412426
await self.write_obj_stream.send(chunk_req)
413427

428+
# Get final response from the last request (which has state_lookup=True)
414429
resp = await self.write_obj_stream.recv()
430+
415431
if resp:
416432
if resp.persisted_size is not None:
417433
self.persisted_size = resp.persisted_size

google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ async def download_ranges(
377377
attempt_count = 0
378378

379379
def send_ranges_and_get_bytes(
380-
requests: List[_storage_v2.ReadRange],
380+
requests_generator,
381381
state: Dict[str, Any],
382382
metadata: Optional[List[Tuple[str, str]]] = None,
383383
):
@@ -387,7 +387,7 @@ async def generator():
387387

388388
if attempt_count > 1:
389389
logger.info(
390-
f"Resuming download (attempt {attempt_count - 1}) for {len(requests)} ranges."
390+
f"Resuming download (attempt {attempt_count - 1})."
391391
)
392392

393393
async with lock:
@@ -436,17 +436,28 @@ async def generator():
436436
)
437437
self._is_stream_open = True
438438

439-
pending_read_ids = {r.read_id for r in requests}
439+
# Stream requests directly without materializing
440+
pending_read_ids = set()
441+
current_batch = []
442+
443+
for read_range in requests_generator:
444+
pending_read_ids.add(read_range.read_id)
445+
current_batch.append(read_range)
446+
447+
# Send batch when it reaches max size
448+
if len(current_batch) >= _MAX_READ_RANGES_PER_BIDI_READ_REQUEST:
449+
await self.read_obj_str.send(
450+
_storage_v2.BidiReadObjectRequest(read_ranges=current_batch)
451+
)
452+
current_batch = []
440453

441-
# Send Requests
442-
for i in range(
443-
0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
444-
):
445-
batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
454+
# Send remaining partial batch
455+
if current_batch:
446456
await self.read_obj_str.send(
447-
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
457+
_storage_v2.BidiReadObjectRequest(read_ranges=current_batch)
448458
)
449459

460+
# Receive responses
450461
while pending_read_ids:
451462
response = await self.read_obj_str.recv()
452463
if response is None:

google/cloud/storage/_experimental/asyncio/retry/base_strategy.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,20 @@ class _BaseResumptionStrategy(abc.ABC):
2727
"""
2828

2929
@abc.abstractmethod
30-
def generate_requests(self, state: Any) -> Iterable[Any]:
31-
"""Generates the next batch of requests based on the current state.
30+
def generate_requests(self, state: Any):
31+
"""Generates requests based on the current state as a generator.
3232
3333
This method is called at the beginning of each retry attempt. It should
34-
inspect the provided state object and generate the appropriate list of
35-
request protos to send to the server. For example, a read strategy
36-
would use this to implement "Smarter Resumption" by creating smaller
37-
`ReadRange` requests for partially downloaded ranges. For bidi-writes,
38-
it will set the `write_offset` field to the persisted size received
39-
from the server in the next request.
34+
inspect the provided state object and yield request protos to send to
35+
the server. For example, a read strategy would use this to implement
36+
"Smarter Resumption" by creating smaller `ReadRange` requests for
37+
partially downloaded ranges. For bidi-writes, it will set the
38+
`write_offset` field to the persisted size received from the server
39+
in the next request.
40+
41+
This is a generator that yields requests incrementally rather than
42+
returning them all at once, allowing for better memory efficiency
43+
and on-demand generation.
4044
4145
:type state: Any
4246
:param state: An object containing all the state needed for the

google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ async def execute(self, initial_state: Any, retry_policy):
5050
state = initial_state
5151

5252
async def attempt():
53-
requests = self._strategy.generate_requests(state)
54-
stream = self._send_and_recv(requests, state)
53+
requests_generator = self._strategy.generate_requests(state)
54+
stream = self._send_and_recv(requests_generator, state)
5555
try:
5656
async for response in stream:
5757
self._strategy.update_state_from_response(response, state)

google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@ def __init__(
4949
class _ReadResumptionStrategy(_BaseResumptionStrategy):
5050
"""The concrete resumption strategy for bidi reads."""
5151

52-
def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]:
52+
def generate_requests(self, state: Dict[str, Any]):
5353
"""Generates new ReadRange requests for all incomplete downloads.
5454
55+
This is a generator that yields requests one at a time for incomplete
56+
downloads, allowing for better memory efficiency and incremental processing.
57+
5558
:type state: dict
5659
:param state: A dictionary mapping a read_id to its corresponding
5760
_DownloadState object.
5861
"""
59-
pending_requests = []
6062
download_states: Dict[int, _DownloadState] = state["download_states"]
6163

6264
for read_id, read_state in download_states.items():
@@ -74,8 +76,7 @@ def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]
7476
read_length=new_length,
7577
read_id=read_id,
7678
)
77-
pending_requests.append(new_request)
78-
return pending_requests
79+
yield new_request
7980

8081
def update_state_from_response(
8182
self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any]

google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,38 @@ class _WriteResumptionStrategy(_BaseResumptionStrategy):
6565

6666
def generate_requests(
6767
self, state: Dict[str, Any]
68-
) -> List[storage_type.BidiWriteObjectRequest]:
68+
):
6969
"""Generates BidiWriteObjectRequests to resume or continue the upload.
7070
71-
This method is not applicable for `open` methods.
71+
This method is a generator that yields requests one at a time,
72+
allowing for incremental sending and better memory efficiency.
73+
74+
On retry/redirect, yields a state_lookup request first to get the current
75+
persisted state from the server before sending data requests.
76+
77+
The last data request is always yielded with state_lookup=True and flush=True
78+
to ensure the server persists the final data and returns the updated state.
7279
"""
7380
write_state: _WriteState = state["write_state"]
7481

75-
requests = []
82+
# If this is a retry/redirect, yield a state lookup request first
83+
# This allows the sender to get current persisted_size before proceeding
84+
if write_state.routing_token or write_state.bytes_sent > write_state.persisted_size:
85+
# Yield an open/state-lookup request with no data
86+
yield storage_type.BidiWriteObjectRequest(state_lookup=True)
87+
7688
# The buffer should already be seeked to the correct position (persisted_size)
7789
# by the `recover_state_on_failure` method before this is called.
7890
while not write_state.is_finalized:
7991
chunk = write_state.user_buffer.read(write_state.chunk_size)
8092

81-
# End of File detection
8293
if not chunk:
8394
break
8495

96+
# Peek to see if this is the last chunk. NOTE(review): io.BufferedReader has
97+
# peek() but io.BytesIO does not, so the fallback marks every BytesIO chunk as last.
98+
is_last_chunk = not getattr(write_state.user_buffer, "peek", lambda n: b"")(1)
99+
85100
checksummed_data = storage_type.ChecksummedData(content=chunk)
86101
checksum = google_crc32c.Checksum(chunk)
87102
checksummed_data.crc32c = int.from_bytes(checksum.digest(), "big")
@@ -102,8 +117,11 @@ def generate_requests(
102117
# reset counter after marking flush
103118
write_state.bytes_since_last_flush = 0
104119

105-
requests.append(request)
106-
return requests
120+
if is_last_chunk:
121+
request.flush = True
122+
request.state_lookup = True
123+
124+
yield request
107125

108126
def update_state_from_response(
109127
self, response: storage_type.BidiWriteObjectResponse, state: Dict[str, Any]

tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ async def mock_send_and_recv(*args, **kwargs):
139139
@pytest.mark.asyncio
140140
async def test_execute_fails_immediately_on_non_retriable_error(self):
141141
mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy)
142+
mock_strategy.generate_requests.return_value = iter([])
142143

143-
async def mock_send_and_recv(*args, **kwargs):
144+
async def mock_send_and_recv(strategy, state, **kwargs):
145+
strategy.generate_requests(state)
144146
if False:
145147
yield
146148
raise exceptions.PermissionDenied("Auth error")

tests/unit/asyncio/retry/test_reads_resumption_strategy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_generate_requests_single_incomplete(self):
109109
read_state = self._add_download(_READ_ID, offset=0, length=100)
110110
read_state.bytes_written = 20
111111

112-
requests = self.strategy.generate_requests(self.state)
112+
requests = list(self.strategy.generate_requests(self.state))
113113

114114
self.assertEqual(len(requests), 1)
115115
self.assertEqual(requests[0].read_offset, 20)
@@ -124,7 +124,7 @@ def test_generate_requests_multiple_incomplete(self):
124124

125125
self._add_download(read_id2, offset=200, length=100)
126126

127-
requests = self.strategy.generate_requests(self.state)
127+
requests = list(self.strategy.generate_requests(self.state))
128128

129129
self.assertEqual(len(requests), 2)
130130
requests.sort(key=lambda r: r.read_id)
@@ -145,7 +145,7 @@ def test_generate_requests_read_to_end_resumption(self):
145145
read_state = self._add_download(_READ_ID, offset=0, length=0)
146146
read_state.bytes_written = 500
147147

148-
requests = self.strategy.generate_requests(self.state)
148+
requests = list(self.strategy.generate_requests(self.state))
149149

150150
self.assertEqual(len(requests), 1)
151151
self.assertEqual(requests[0].read_offset, 500)
@@ -156,7 +156,7 @@ def test_generate_requests_with_complete(self):
156156
read_state = self._add_download(_READ_ID)
157157
read_state.is_complete = True
158158

159-
requests = self.strategy.generate_requests(self.state)
159+
requests = list(self.strategy.generate_requests(self.state))
160160
self.assertEqual(len(requests), 0)
161161

162162
def test_generate_requests_multiple_mixed_states(self):
@@ -170,7 +170,7 @@ def test_generate_requests_multiple_mixed_states(self):
170170
s3 = self._add_download(3, offset=200, length=100)
171171
s3.bytes_written = 0
172172

173-
requests = self.strategy.generate_requests(self.state)
173+
requests = list(self.strategy.generate_requests(self.state))
174174

175175
self.assertEqual(len(requests), 2)
176176
requests.sort(key=lambda r: r.read_id)
@@ -180,7 +180,7 @@ def test_generate_requests_multiple_mixed_states(self):
180180

181181
def test_generate_requests_empty_state(self):
182182
"""Test generating requests with an empty state."""
183-
requests = self.strategy.generate_requests(self.state)
183+
requests = list(self.strategy.generate_requests(self.state))
184184
self.assertEqual(len(requests), 0)
185185

186186
# --- Update State and response processing Tests ---

tests/unit/asyncio/retry/test_writes_resumption_strategy.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_generate_requests_initial_chunking(self, strategy):
4848
write_state = _WriteState(chunk_size=3, user_buffer=mock_buffer)
4949
state = {"write_state": write_state}
5050

51-
requests = strategy.generate_requests(state)
51+
requests = list(strategy.generate_requests(state))
5252

5353
# Expected: 4 requests (3, 3, 3, 1)
5454
assert len(requests) == 4
@@ -85,7 +85,7 @@ def test_generate_requests_resumption(self, strategy):
8585

8686
state = {"write_state": write_state}
8787

88-
requests = strategy.generate_requests(state)
88+
requests = list(strategy.generate_requests(state))
8989

9090
# Since 4 bytes are done, we expect remaining 6 bytes: [4 bytes, 2 bytes]
9191
assert len(requests) == 2
@@ -104,7 +104,7 @@ def test_generate_requests_empty_file(self, strategy):
104104
write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer)
105105
state = {"write_state": write_state}
106106

107-
requests = strategy.generate_requests(state)
107+
requests = list(strategy.generate_requests(state))
108108

109109
assert len(requests) == 0
110110

@@ -115,7 +115,7 @@ def test_generate_requests_checksum_verification(self, strategy):
115115
write_state = _WriteState(chunk_size=10, user_buffer=mock_buffer)
116116
state = {"write_state": write_state}
117117

118-
requests = strategy.generate_requests(state)
118+
requests = list(strategy.generate_requests(state))
119119

120120
expected_crc = google_crc32c.Checksum(chunk_data).digest()
121121
expected_int = int.from_bytes(expected_crc, "big")
@@ -130,7 +130,7 @@ def test_generate_requests_flush_logic_exact_interval(self, strategy):
130130
)
131131
state = {"write_state": write_state}
132132

133-
requests = strategy.generate_requests(state)
133+
requests = list(strategy.generate_requests(state))
134134

135135
# Request index 1 (4 bytes total) should have flush=True
136136
assert requests[0].flush is False
@@ -155,7 +155,7 @@ def test_generate_requests_flush_logic_none_interval(self, strategy):
155155
)
156156
state = {"write_state": write_state}
157157

158-
requests = strategy.generate_requests(state)
158+
requests = list(strategy.generate_requests(state))
159159

160160
for req in requests:
161161
assert req.flush is False
@@ -169,7 +169,7 @@ def test_generate_requests_flush_logic_data_less_than_interval(self, strategy):
169169
)
170170
state = {"write_state": write_state}
171171

172-
requests = strategy.generate_requests(state)
172+
requests = list(strategy.generate_requests(state))
173173

174174
# Total 5 bytes < 10 bytes interval
175175
for req in requests:
@@ -184,7 +184,7 @@ def test_generate_requests_honors_finalized_state(self, strategy):
184184
write_state.is_finalized = True
185185
state = {"write_state": write_state}
186186

187-
requests = strategy.generate_requests(state)
187+
requests = list(strategy.generate_requests(state))
188188
assert len(requests) == 0
189189

190190
@pytest.mark.asyncio
@@ -217,7 +217,7 @@ async def test_generate_requests_after_failure_and_recovery(self, strategy):
217217
# 2. bytes_sent should track persisted_size (4)
218218
assert write_state.bytes_sent == 4
219219

220-
requests = strategy.generate_requests(state)
220+
requests = list(strategy.generate_requests(state))
221221

222222
# Remaining data from offset 4 to 16 (12 bytes total)
223223
# Chunks: [4-8], [8-12], [12-16]

0 commit comments

Comments
 (0)