GeneratorReader: don't copy so much

this was exposed when dealing with generators that can yield in very big
chunks (e.g. brotli bombs).

tests now target the GeneratorReader itself directly, rather than
integrating it with particular generators-under-test.

Klaas van Schelven
2025-11-07 23:48:06 +01:00
parent aab062a11e
commit 26f327a257
2 changed files with 51 additions and 61 deletions
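
The gist of the fix: the old implementation buffered in an immutable bytes
object, so every read(size) rebuilt the remaining tail via self.unread[size:],
copying it in full; n small reads over one big chunk add up to O(n^2) bytes
copied. A bytearray with del buffer[:size] sidesteps this, since CPython
deletes a leading slice of a bytearray by advancing an internal offset rather
than copying. A minimal standalone sketch of the difference (not part of this
commit; it uses 200_000 rather than the test's 500_000 so the slow variant
stays tolerable):

    import time

    def drain_bytes(total=200_000):
        # old approach: immutable bytes, re-sliced on every 1-byte read
        unread = b"x" * total
        t0 = time.perf_counter()
        while unread:
            unread = unread[1:]  # copies the whole remaining tail: O(n) per read
        return time.perf_counter() - t0

    def drain_bytearray(total=200_000):
        # new approach: mutable bytearray, prefix deleted in place
        buffer = bytearray(b"x" * total)
        t0 = time.perf_counter()
        while buffer:
            del buffer[:1]  # CPython advances an internal offset: amortized O(1)
        return time.perf_counter() - t0

    print("bytes slicing: %.3fs" % drain_bytes())      # visibly slow: quadratic
    print("bytearray del: %.3fs" % drain_bytearray())  # near-instant: roughly linear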

@@ -56,30 +56,31 @@ def brotli_generator(input_stream, chunk_size=DEFAULT_CHUNK_SIZE):
 class GeneratorReader:
     """Read from a generator as from a file-like object."""
 
     def __init__(self, generator):
         self.generator = generator
-        self.unread = b""
+        self.buffer = bytearray()
 
     def read(self, size=None):
         if size is None:
             for chunk in self.generator:
-                self.unread += chunk
-            result = self.unread
-            self.unread = b""
+                self.buffer.extend(chunk)
+            result = bytes(self.buffer)
+            self.buffer.clear()
             return result
 
-        while size > len(self.unread):
+        while len(self.buffer) < size:
             try:
                 chunk = next(self.generator)
-                if chunk == b"":
+                if not chunk:
                     break
-                self.unread += chunk
+                self.buffer.extend(chunk)
             except StopIteration:
                 break
 
-        self.unread, result = self.unread[size:], self.unread[:size]
+        result = bytes(self.buffer[:size])
+        del self.buffer[:size]
         return result
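
For illustration (not part of the diff): GeneratorReader keeps the usual
file-like contract, so a consumer reads fixed-size chunks until an empty
bytes object signals exhaustion, or passes size=None to drain everything
at once. A sketch, assuming GeneratorReader is imported from the module
above:

    def words():
        yield b"alpha "
        yield b"beta "
        yield b"gamma"

    reader = GeneratorReader(words())

    chunks = []
    while True:
        chunk = reader.read(4)  # fixed-size reads, like file.read(4)
        if chunk == b"":        # empty result: the generator is exhausted
            break
        chunks.append(chunk)

    assert b"".join(chunks) == b"alpha beta gamma"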

@@ -43,55 +43,31 @@ class StreamsTestCase(RegularTestCase):
     def test_compress_decompress_gzip(self):
         with open(__file__, 'rb') as f:
             myself_times_ten = f.read() * 10
 
         plain_stream = io.BytesIO(myself_times_ten)
         compressed_stream = io.BytesIO(compress_with_zlib(plain_stream, WBITS_PARAM_FOR_GZIP))
 
-        result = b""
         reader = GeneratorReader(zlib_generator(compressed_stream, WBITS_PARAM_FOR_GZIP))
-        while True:
-            chunk = reader.read(3)
-            result += chunk
-            if chunk == b"":
-                break
-
-        self.assertEqual(myself_times_ten, result)
+        self.assertEqual(myself_times_ten, reader.read())
 
     def test_compress_decompress_deflate(self):
         with open(__file__, 'rb') as f:
             myself_times_ten = f.read() * 10
 
         plain_stream = io.BytesIO(myself_times_ten)
         compressed_stream = io.BytesIO(compress_with_zlib(plain_stream, WBITS_PARAM_FOR_DEFLATE))
 
-        result = b""
         reader = GeneratorReader(zlib_generator(compressed_stream, WBITS_PARAM_FOR_DEFLATE))
-        while True:
-            chunk = reader.read(3)
-            result += chunk
-            if chunk == b"":
-                break
-
-        self.assertEqual(myself_times_ten, result)
+        self.assertEqual(myself_times_ten, reader.read())
 
     def test_compress_decompress_brotli(self):
         with open(__file__, 'rb') as f:
             myself_times_ten = f.read() * 10
 
         compressed_stream = io.BytesIO(brotli.compress(myself_times_ten))
 
-        result = b""
         reader = GeneratorReader(brotli_generator(compressed_stream))
-        while True:
-            chunk = reader.read(3)
-            result += chunk
-            if chunk == b"":
-                break
-
-        self.assertEqual(myself_times_ten, result)
+        self.assertEqual(myself_times_ten, reader.read())
 
     def test_decompress_brotli_tiny_bomb(self):
         # by picking something "sufficiently large" we can ensure all three code paths in brotli_generator are taken,
@@ -99,29 +75,11 @@ class StreamsTestCase(RegularTestCase):
         # side)
         compressed_stream = io.BytesIO(brotli.compress(b"\x00" * 15_000_000))
 
-        result = b""
-        reader = GeneratorReader(brotli_generator(compressed_stream))
-        while True:
-            chunk = reader.read(3)
-            result += chunk
-            if chunk == b"":
-                break
-
-        self.assertEqual(b"\x00" * 15_000_000, result)
-
-    def test_compress_decompress_read_none(self):
-        with open(__file__, 'rb') as f:
-            myself_times_ten = f.read() * 10
-
-        plain_stream = io.BytesIO(myself_times_ten)
-        compressed_stream = io.BytesIO(compress_with_zlib(plain_stream, WBITS_PARAM_FOR_DEFLATE))
-
-        result = b""
-        reader = GeneratorReader(zlib_generator(compressed_stream, WBITS_PARAM_FOR_DEFLATE))
-        result = reader.read(None)
-
-        self.assertEqual(myself_times_ten, result)
+        size = 0
+        generator = brotli_generator(compressed_stream)
+        for chunk in generator:
+            size += len(chunk)
+
+        self.assertEqual(15_000_000, size)
 
     def test_max_data_reader(self):
         stream = io.BytesIO(b"hello" * 100)
@@ -160,6 +118,37 @@ class StreamsTestCase(RegularTestCase):
         with self.assertRaises(ValueError):
             writer.write(b"hellohello")
 
+    def test_generator_reader(self):
+        def generator():
+            yield b"hello "
+            yield b"I am "
+            yield b"a generator"
+
+        reader = GeneratorReader(generator())
+        self.assertEqual(b"hel", reader.read(3))
+        self.assertEqual(b"lo ", reader.read(3))
+        self.assertEqual(b"I a", reader.read(3))
+        self.assertEqual(b"m a", reader.read(3))
+        self.assertEqual(b" generator", reader.read(None))
+
+    def test_generator_reader_performance(self):
+        # at least one test directly for GeneratorReader; doubles as a regression test for a performance issue that
+        # showed up when the underlying generator yielded relatively big chunks and the read() size was small. should
+        # run easily under a second.
+        def yielding_big_chunks():
+            yield b"x" * 500_000
+
+        read = []
+        reader = GeneratorReader(yielding_big_chunks())
+        while True:
+            chunk = reader.read(1)
+            if chunk == b"":
+                break
+            read.append(chunk)
 
 
 @override_settings(DEBUG_CSRF=True)
 class CSRFViewsTestCase(DjangoTestCase):