Split a generator into chunks without pre-walking it

blueFast · Jul 2, 2014

(This question is related to this one and this one, but the answers there pre-walk the generator, which is exactly what I want to avoid.)

I would like to split a generator into chunks. The requirements are:

  • do not pad the chunks: if the number of remaining elements is less than the chunk size, the last chunk must be smaller.
  • do not walk the generator beforehand: computing the elements is expensive, and it must only be done by the consuming function, not by the chunker.
  • which means, of course: do not accumulate elements in memory (no lists).
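To check the second requirement in practice, it helps to use a source that records the moment each element is actually computed. A minimal Python 3 sketch; expensive_source and the computed list are hypothetical names, and the multiplication only stands in for the real, costly work:

```python
def expensive_source(n, computed):
    """Hypothetical stand-in for the real, costly generator.

    Appends each index to `computed` at the moment it is produced, so a
    test can see whether a chunker has run ahead of the consumer.
    """
    for i in range(n):
        computed.append(i)   # record that element i is being computed now
        yield i * 10         # placeholder for the expensive computation


computed = []
gen = expensive_source(5, computed)
print(computed)              # prints [] (creating the generator computes nothing)
first = next(gen)
print(first, computed)       # prints 0 [0] (only the requested element was computed)
```

A chunker meets the "no pre-walking" requirement if, wrapped around such a source, computed never runs ahead of what the consumer has asked for.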

I have tried the following code:

def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        yield head(i, size)

# Sample generator: the real data is much more complex, and expensive to compute
els = xrange(7)

for n, chunk in enumerate(chunks(els, 3)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

And this somehow works:

Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
^CTraceback (most recent call last):
  File "xxxx.py", line 15, in <module>
    for el in chunk:
  File "xxxx.py", line 2, in head
    for cnt, el in enumerate(iterable):
KeyboardInterrupt
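Why the loop cannot end on its own: once the shared iterator is exhausted, each further head(i, size) is still a perfectly valid generator object that simply yields nothing, so while True keeps producing empty chunks forever. A small Python 3 sketch of that behaviour, using itertools.islice as a stand-in for head:

```python
from itertools import islice

it = iter(range(4))

# Each call builds a new chunk over the same shared iterator.
chunk_a = list(islice(it, 3))   # takes 0, 1, 2
chunk_b = list(islice(it, 3))   # only 3 is left
chunk_c = list(islice(it, 3))   # iterator is exhausted: an empty chunk
chunk_d = list(islice(it, 3))   # ...and another empty chunk, with no error

print(chunk_a, chunk_b, chunk_c, chunk_d)   # prints [0, 1, 2] [3] [] []
```

Nothing in this sequence raises, so a loop that only creates and yields chunks has no signal to stop on.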

But... it never stops (I have to press ^C) because of the while True. I would like to stop that loop as soon as the generator is exhausted, but I do not know how to detect that. I have tried raising an exception:

class NoMoreData(Exception):
    pass

def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break
    if cnt == 0 : raise NoMoreData()

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        try:
            yield head(i, size)
        except NoMoreData:
            break

# Sample generator: the real data is much more complex, and expensive to compute    
els = xrange(7)

for n, chunk in enumerate(chunks(els, 2)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

But then the exception is raised in the context of the consumer (in for el in chunk), not in chunks, which is not what I want: I want to keep the consumer code clean.

Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
Traceback (most recent call last):
  File "xxxx.py", line 22, in <module>
    for el in chunk:
  File "xxxx.py", line 9, in head
    if cnt == 0 : raise NoMoreData
__main__.NoMoreData()
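The exception reaches the consumer because head is itself a generator function: calling head(i, size) only creates a generator object and runs none of its body, so the yield inside the try block of chunks never fails. The body, including the raise, runs later, when the consumer iterates the chunk. A minimal Python 3 sketch with hypothetical names (lazy_head, outer, Boom):

```python
class Boom(Exception):
    pass


def lazy_head():
    # A generator function: none of this runs until someone iterates it.
    raise Boom()
    yield  # never reached; its presence makes this a generator function


def outer():
    try:
        yield lazy_head()        # creating the generator cannot raise Boom
    except Boom:
        yield "caught in outer"  # never reached


caught_by = None
for chunk in outer():
    try:
        for el in chunk:         # Boom is raised here, in the consumer
            pass
    except Boom:
        caught_by = "consumer"

print(caught_by)   # prints consumer
```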

How can I detect that the generator is exhausted in the chunks function, without walking it?

Answer

tobias_k · Jul 2, 2014

One way would be to peek at the first element, if any, and then create and return the actual generator.

def head(iterable, max=10):
    first = next(iterable)      # raise StopIteration when depleted (needs an iterator)
    def head_inner():
        yield first             # yield the extracted first element
        for _ in range(max - 1):    # at most max - 1 more, so max in total
            try:
                yield next(iterable)
            except StopIteration:   # depleted in the middle of a chunk
                return
    return head_inner()

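Just use this in your chunks generator and catch the StopIteration exception there, as you did with your custom exception. Unlike NoMoreData, this works: head is now a plain function that calls next() immediately, so the exception is raised inside chunks rather than later in the consumer. (On Python 3.7+ the explicit catch is required anyway, because a StopIteration that escapes a generator body is turned into a RuntimeError, per PEP 479.)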
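Putting the two pieces together might look like the following Python 3 sketch. head here is a variant of the peek-based head above that yields at most size elements, and chunks catches the StopIteration from the eager next() call and returns:

```python
def head(iterator, size):
    """Return a generator over the next `size` elements of `iterator`.

    The first element is fetched eagerly, so StopIteration is raised right
    here, in the caller, when the iterator is already exhausted.
    """
    first = next(iterator)

    def head_inner():
        yield first
        for _ in range(size - 1):        # at most size - 1 further elements
            try:
                yield next(iterator)
            except StopIteration:        # source ran out mid-chunk
                return

    return head_inner()


def chunks(iterable, size=10):
    iterator = iter(iterable)
    while True:
        try:
            chunk = head(iterator, size)
        except StopIteration:            # nothing left: end the outer generator
            return
        yield chunk


result = [list(chunk) for chunk in chunks(range(7), 3)]
print(result)   # prints [[0, 1, 2], [3, 4, 5], [6]]
```

The try block wraps only the head() call, since that is now the only place the end of the data can surface.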


Update: Here's another version, using itertools.islice to replace most of the head function and a for loop to drive the chunking. The for loop does exactly the same thing as the unwieldy while/try/next/except/break construct used earlier (it calls next() and stops cleanly on StopIteration), so the result is much more readable.

from itertools import islice

def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:    # stops when iterator is depleted
        def chunk():          # construct generator for next chunk
            yield first       # yield element from for loop
            for more in islice(iterator, size - 1):
                yield more    # yield more elements from the iterator
        yield chunk()         # in outer generator, yield next chunk
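A quick Python 3 check of this version, with a hypothetical recording source (tracked) standing in for the expensive generator. It confirms that the last chunk is not padded and that elements are only computed as the consumer reaches them:

```python
from itertools import islice


def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:    # stops when iterator is depleted
        def chunk():          # construct generator for next chunk
            yield first       # yield element from for loop
            for more in islice(iterator, size - 1):
                yield more    # yield more elements from the iterator
        yield chunk()         # in outer generator, yield next chunk


computed = []


def tracked(n):
    # Hypothetical stand-in for the expensive source: logs each element
    # at the moment it is produced.
    for i in range(n):
        computed.append(i)
        yield i


chunker = chunks(tracked(7), 3)
first_chunk = next(chunker)
print(computed)                 # prints [0] (only the peeked element so far)
print(list(first_chunk))        # prints [0, 1, 2]
print(computed)                 # prints [0, 1, 2] (nothing beyond the first chunk)

rest = [list(c) for c in chunker]
print(rest)                     # prints [[3, 4, 5], [6]]
```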

And we can get even shorter than that, using itertools.chain to replace the inner generator:

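from itertools import chain, islice

def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:    # stops when iterator is depleted
        # chain the already-fetched first element with up to size - 1 more
        yield chain([first], islice(iterator, size - 1))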
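One caveat applies to all of these lazy chunkers, this one included: every chunk reads from the same underlying iterator, so each chunk has to be consumed before the next one is requested. Collecting the chunk objects first and reading them afterwards gives surprising results, as this Python 3 sketch shows:

```python
from itertools import chain, islice


def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:
        yield chain([first], islice(iterator, size - 1))


# Consuming each chunk as soon as it is produced works as intended.
good = [list(chunk) for chunk in chunks(range(7), 3)]
print(good)   # prints [[0, 1, 2], [3, 4, 5], [6]]

# Materialising the outer generator first lets the for loop inside chunks
# consume every element as a "first" element, so each chunk holds one item.
held = list(chunks(range(7), 3))
bad = [list(chunk) for chunk in held]
print(bad)    # prints [[0], [1], [2], [3], [4], [5], [6]]
```

If a caller needs to keep a chunk around, it can call list(chunk) before asking for the next one, at the cost of holding that chunk in memory.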