Vectorized prefix sum

Kragen Javier Sitaker, 2017-07-19 (5 minutes)

You can apply the parallel prefix sum algorithm in a serial, vectorized way using a logarithmic number of Numpy’s vector operations. The parallel prefix sum algorithm puts the vector elements into the leaves of a complete N-ary tree (here I use N=2) and propagates partial prefix sums first up and then back down this tree.

Here you have the sum of 16 numbers in four elementwise-parallel operations using strided arrays:

>>> import numpy
>>> x = numpy.array([1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8])
>>> a = x[::2] + x[1::2]
>>> a
array([ 3,  7, 11, 15,  3,  7, 11, 15])
>>> b = a[::2] + a[1::2]
>>> b
array([10, 26, 10, 26])
>>> c = b[::2] + b[1::2]
>>> c
array([36, 36])
>>> d = c[::2] + c[1::2]
>>> d
array([72])
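
The four steps above all have the same shape, so they can be written as a loop. The following is a minimal sketch, not part of the original session: the name pairwise_levels is made up for illustration, and it assumes the input length is a power of two.

```python
import numpy

def pairwise_levels(x):
    """Return [x, a, b, ...], where each level holds the pairwise sums of
    the level before it.  Hypothetical helper; assumes len(x) is a power
    of two."""
    levels = [numpy.asarray(x)]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        # One elementwise-parallel addition of two strided views.
        levels.append(prev[::2] + prev[1::2])
    return levels

x = numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8])
levels = pairwise_levels(x)
```

Here levels[1] through levels[4] correspond to a, b, c, and d above, so 16 inputs need four additions.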

You can then reverse the order of these steps to propagate the prefix sums back down to the leaves. The 72 in d is of course the prefix sum at the end of the right half of the tree, and the first element of c is already the prefix sum at the end of the left half of the tree. For reasons that will become somewhat less obscure below, we can transpose and ravel to perfect-shuffle them into the right order:

>>> cp = numpy.array((c[::2] + numpy.concatenate(([0], d))[:-1], d)).T.ravel()
>>> cp
array([36, 72])

Half of the full prefix sums at the points represented by b are simply the elements of this cp. The other half are the even-indexed elements of b, each added to the preceding element of cp, or to 0 if there is none.

>>> bp = numpy.array((b[::2] + numpy.concatenate(([0], cp))[:-1], cp)).T.ravel()
>>> bp
array([10, 36, 46, 72])

Then we can do the same trick to get an expanded version of a:

>>> ap = numpy.array((a[::2] + numpy.concatenate(([0], bp))[:-1], bp)).T.ravel()
>>> ap
array([ 3, 10, 21, 36, 39, 46, 57, 72])

And similarly to get the whole prefix sum:

>>> xp = numpy.array((x[::2] + numpy.concatenate(([0], ap))[:-1], ap)).T.ravel()
>>> xp
array([ 1,  3,  6, 10, 15, 21, 28, 36, 37, 39, 42, 46, 51, 57, 64, 72])

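A thing to note about this expression is that none of its operations costs more than O(N) in the length of the level being expanded: the prepending of the 0, the addition, the stacking done by numpy.array, and the ravel, which has to copy because the transposed array is not contiguous. The stride-indexing and transposition operators are constant-time, since they only create views.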
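
Since every level uses the same expression, the whole computation can be folded into one function. This is a sketch under assumptions rather than a polished implementation: the name prefix_sum is invented here, and the input length must be a power of two.

```python
import numpy

def prefix_sum(x):
    """Inclusive prefix sum by an up-sweep followed by a down-sweep.
    Hypothetical helper; assumes len(x) is a power of two."""
    # Up-sweep: pairwise sums, halving the length at each level.
    levels = [numpy.asarray(x)]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        levels.append(prev[::2] + prev[1::2])

    # Down-sweep: the one-element top level is its own prefix sum.
    sums = levels[-1]
    for level in reversed(levels[:-1]):
        # Prefix sum just before each pair, or 0 before the first pair.
        before = numpy.concatenate(([0], sums))[:-1]
        # Interleave the new even-position sums with the known odd ones.
        sums = numpy.array((level[::2] + before, sums)).T.ravel()
    return sums

x = numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8])
xp = prefix_sum(x)
```

Each pass through either loop touches an array half the size of its neighbor, so the total work is linear in the input length even though there are logarithmically many passes.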

In a sense, this is not very interesting, because numpy already has a prefix-sum operator:

>>> numpy.cumsum(x)
array([ 1,  3,  6, 10, 15, 21, 28, 36, 37, 39, 42, 46, 51, 57, 64, 72])

But the above transformation can be generalized to any monoid, not just integer addition, and in particular to function composition, which means it can in theory be used to generate the sequence produced by any definite loop that can be expressed with these operations. It’s a little bit tricky in that the monoid elements are functions representing the action of an iteration of the loop and their composition takes the place of addition in the above calculations. If you’re going to represent them with numpy array elements, they need to be in some sense constant-space, and you may need many arrays.

As a useless and less trivial but still comprehensible example, consider the action taken by a loop decoding a digit string in some base b:

result = 0
for i in range(k):
    result = b * result + digits[i]

The action taken by any n iterations of the loop is to multiply the previous result by bⁿ and add some number m to it. To compose two such actions, (n₀, m₀) followed by (n₁, m₁), we compute (n₀ + n₁, m₀ · b^n₁ + m₁); processing a single digit m merely produces (1, m). So we need one array (at each level) for n and one for m. Moving up the tree:
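
Before vectorizing, it may help to check the composition rule on plain Python integers. The function names compose and apply_action below are invented for this sketch, and base 10 is just a convenient default.

```python
def compose(first, second, base=10):
    """Combine two loop actions (n, m): do `first`, then `second`."""
    n0, m0 = first
    n1, m1 = second
    return (n0 + n1, m0 * base**n1 + m1)

def apply_action(action, result, base=10):
    """Apply an action (n, m) to a running result."""
    n, m = action
    return result * base**n + m

digits = [3, 1, 4, 1, 5]
actions = [(1, d) for d in digits]   # each single digit m becomes (1, m)

# Fold the actions together left to right.
total = actions[0]
for act in actions[1:]:
    total = compose(total, act)

# Grouping does not matter, which is what lets a tree replace the loop.
a, b, c = actions[:3]
left_first = compose(compose(a, b), c)
right_first = compose(a, compose(b, c))
```

Here total comes out as (5, 31415), and applying it to an initial result of 0 gives the same number as running the loop over the five digits.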

>>> na = numpy.ones(len(x))
>>> ma = x
>>> nb = na[::2] + na[1::2]
>>> nb
array([ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.])
>>> mb = ma[::2] * 10**na[1::2] + ma[1::2]
>>> mb
array([ 12.,  34.,  56.,  78.,  12.,  34.,  56.,  78.])
>>> nc = nb[::2] + nb[1::2]
>>> mc = mb[::2] * 10**nb[1::2] + mb[1::2]
>>> nc, mc
(array([ 4.,  4.,  4.,  4.]), array([ 1234.,  5678.,  1234.,  5678.]))
>>> nd = nc[::2] + nc[1::2]
>>> md = mc[::2] * 10**nc[1::2] + mc[1::2]
>>> nd, md
(array([ 8.,  8.]), array([ 12345678.,  12345678.]))
>>> ne = nd[::2] + nd[1::2]
>>> me = md[::2] * 10**nd[1::2] + md[1::2]
>>> me
array([  1.23456781e+15])

If we move back down the tree, we get all the intermediate results. This is somewhat tricky, though conceptually straightforward, and is left as an exercise to the reader, or to me at a later date maybe.
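
One way the exercise might go is sketched below. It is a guess at a reasonable approach, not the author's solution: the names compose, interleave, and digit_prefixes are made up, the input length is assumed to be a power of two, and it uses 64-bit integer arrays instead of the floats in the session above so that all 16 digits stay exact. The pair (0, 0) serves as the identity action.

```python
import numpy

def compose(n0, m0, n1, m1, base=10):
    """Elementwise (n0, m0) followed by (n1, m1)."""
    return n0 + n1, m0 * base**n1 + m1

def interleave(evens, odds):
    """Perfect-shuffle two equal-length arrays."""
    return numpy.array((evens, odds)).T.ravel()

def digit_prefixes(digits, base=10):
    """Return arrays (n, m) holding the action of every prefix of digits.
    Hypothetical helper; assumes len(digits) is a power of two."""
    digits = numpy.asarray(digits, dtype=numpy.int64)
    levels = [(numpy.ones(len(digits), dtype=numpy.int64), digits)]
    # Up-sweep: compose neighboring pairs of actions.
    while len(levels[-1][0]) > 1:
        n, m = levels[-1]
        levels.append(compose(n[::2], m[::2], n[1::2], m[1::2], base))
    # Down-sweep: start from the single action covering all digits.
    pn, pm = levels[-1]
    for n, m in reversed(levels[:-1]):
        # Action of everything before each pair; identity before the first.
        before_n = numpy.concatenate(([0], pn))[:-1]
        before_m = numpy.concatenate(([0], pm))[:-1]
        even_n, even_m = compose(before_n, before_m, n[::2], m[::2], base)
        pn, pm = interleave(even_n, pn), interleave(even_m, pm)
    return pn, pm

x = numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8])
pn, pm = digit_prefixes(x)
```

The m array then holds the value of result after each iteration of the original loop, and the n array holds the iteration counts 1 through 16.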

Note that everything in the above has used only the monoid operation and the identity element 0. But we don’t actually need the 0; the algorithm is still applicable to semigroups lacking an identity element. A practical example of this is finding the minimum or maximum of the elements seen so far.
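
To illustrate the semigroup case, here is a sketch of the same scan written so that it never needs an identity element: the first even-position result is copied straight from the input instead of being combined with a prepended 0. The name scan is invented for this example, it recurses rather than keeping explicit levels, and it assumes a power-of-two length and an associative elementwise operation.

```python
import numpy

def scan(x, op):
    """Inclusive scan of x under an associative elementwise `op`.
    Hypothetical helper; needs no identity element, but assumes
    len(x) is a power of two."""
    x = numpy.asarray(x)
    if len(x) == 1:
        return x
    # Scan of the combined pairs gives the results at odd positions.
    sums = scan(op(x[::2], x[1::2]), op)
    # Even positions: the first is x[0] itself; each later one combines
    # the preceding odd-position result (on the left) with x's element.
    evens = numpy.concatenate((x[:1], op(sums[:-1], x[2::2])))
    return numpy.array((evens, sums)).T.ravel()

x = numpy.array([3, 1, 4, 1, 5, 9, 2, 6])
running_max = scan(x, numpy.maximum)
running_min = scan(x, numpy.minimum)
```

NumPy can also compute these directly with numpy.maximum.accumulate and numpy.minimum.accumulate, which makes a convenient cross-check.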
