Select k random elements from a list whose elements have weights

If the sampling is with replacement, you can use this algorithm (implemented here in Python):

import random

items = [(10, "low"),
         (100, "mid"),
         (890, "large")]

def weighted_sample(items, n):
    total = float(sum(w for w, v in items))
    i = 0
    w, v = items[0]
    while n:
        x = total * (1 - random.random() ** (1.0 / n))
        total -= x
        while x > w and i + 1 < len(items):  # guard: round-off must not walk past the last item
            x -= w
            i += 1
            w, v = items[i]
        w -= x
        yield v
        n -= 1

This is O(n + m), where n is the number of samples and m is the number of items.

Why does this work? It is based on the following algorithm:

def n_random_numbers_decreasing(v, n):
    """Like reversed(sorted(v * random.random() for i in range(n))),
    but faster because we avoid sorting."""
    while n:
        v *= random.random() ** (1.0 / n)
        yield v
        n -= 1

The function weighted_sample is just this algorithm fused with a walk of the items list to pick out the items selected by those random numbers.

This in turn works because the probability that n random numbers 0..v will all happen to be less than z is P = (z/v)^n. Solve for z, and you get z = v * P^(1/n). Substituting a random number for P picks the largest number with the correct distribution; and we can just repeat the process to select all the other numbers.
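
A quick way to convince yourself of the formula is to compare it empirically against the brute-force approach. The sketch below is only an illustration: the helper names, the seed, and the trial count are arbitrary choices, not part of the algorithm above. It checks that both methods give a largest value whose mean is close to the theoretical v * n / (n + 1).

```python
import random

def largest_of_n_direct(v, n, rng):
    # Brute force: draw n uniform numbers in [0, v) and keep the largest.
    return max(v * rng.random() for _ in range(n))

def largest_of_n_shortcut(v, n, rng):
    # Shortcut: invert P = (z/v)^n to get z = v * P^(1/n), with P random.
    return v * rng.random() ** (1.0 / n)

def mean(xs):
    xs = list(xs)
    return sum(xs) / len(xs)

rng = random.Random(12345)        # fixed seed so the run is repeatable
v, n, trials = 10.0, 5, 20_000    # illustrative values only

direct = mean(largest_of_n_direct(v, n, rng) for _ in range(trials))
shortcut = mean(largest_of_n_shortcut(v, n, rng) for _ in range(trials))
expected = v * n / (n + 1)        # theoretical mean of the maximum

print(direct, shortcut, expected)
```

The two empirical means should agree with each other and with the theoretical value up to sampling noise.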

If the sampling is without replacement, you can put all the items into a binary heap, where each node caches the total of the weights of all items in that subheap. Building the heap is O(m). Selecting a random item from the heap, respecting the weights, is O(log m). Removing that item and updating the cached totals is also O(log m). So you can pick n items in O(m + n log m) time.

(Note: "weight" here means that every time an element is selected, the remaining possibilities are chosen with probability proportional to their weights. It does not mean that elements appear in the output with a likelihood proportional to their weights.)

Here's an implementation of that, plentifully commented:

import random

class Node:
    # Each node in the heap has a weight, value, and total weight.
    # The total weight, self.tw, is self.w plus the total weight of any children.
    __slots__ = ['w', 'v', 'tw']
    def __init__(self, w, v, tw):
        self.w, self.v, self.tw = w, v, tw

def rws_heap(items):
    # h is the heap. It's like a binary tree that lives in an array.
    # It has a Node for each pair in `items`. h[1] is the root. Each
    # other Node h[i] has a parent at h[i>>1]. Each node has up to 2
    # children, h[i<<1] and h[(i<<1)+1].  To get this nice simple
    # arithmetic, we have to leave h[0] vacant.
    h = [None]                          # leave h[0] vacant
    for w, v in items:
        h.append(Node(w, v, w))
    for i in range(len(h) - 1, 1, -1):  # total up the tws
        h[i>>1].tw += h[i].tw           # add h[i]'s total to its parent
    return h

def rws_heap_pop(h):
    gas = h[1].tw * random.random()     # start with a random amount of gas

    i = 1                     # start driving at the root
    while gas >= h[i].w:      # while we have enough gas to get past node i:
        gas -= h[i].w         #   drive past node i
        i <<= 1               #   move to first child
        if gas >= h[i].tw:    #   if we have enough gas:
            gas -= h[i].tw    #     drive past first child and descendants
            i += 1            #     move to second child
    w = h[i].w                # out of gas! h[i] is the selected node.
    v = h[i].v

    h[i].w = 0                # make sure this node isn't chosen again
    while i:                  # fix up total weights
        h[i].tw -= w
        i >>= 1
    return v

def random_weighted_sample_no_replacement(items, n):
    heap = rws_heap(items)              # just make a heap...
    for i in range(n):
        yield rws_heap_pop(heap)        # and pop n items off it.
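
To make the earlier note about what "weight" means concrete, here is a small exact calculation. It is a sketch, separate from the heap code: the function name is made up for illustration, and it brute-forces every possible draw order, so it is only usable on tiny inputs with integer weights.

```python
from fractions import Fraction

def inclusion_probabilities(weights, k):
    """Exact probability that each index ends up among k draws made one
    at a time without replacement, where every draw is proportional to
    the weights of the items still available. Integer weights assumed."""
    probs = [Fraction(0)] * len(weights)

    def recurse(remaining, chosen, p):
        if len(chosen) == k:
            for i in chosen:
                probs[i] += p
            return
        total = sum(weights[i] for i in remaining)
        for i in remaining:
            recurse(remaining - {i}, chosen + [i],
                    p * Fraction(weights[i], total))

    recurse(frozenset(range(len(weights))), [], Fraction(1))
    return probs

print(inclusion_probabilities([1, 1, 2], 2))
# [Fraction(7, 12), Fraction(7, 12), Fraction(5, 6)]
```

With weights 1, 1, 2 and two draws, the heavy item appears with probability 5/6 and each light item with probability 7/12. If appearance were proportional to weight, those would have to be 1 and 1/2.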

I know this is a very old question, but I think there's a neat trick to do this (sampling without replacement) in O(n) time, where n is the number of elements, if you apply a little math!

The exponential distribution has two very useful properties.

  1. Given n independent samples from exponential distributions with different rate parameters, the probability that a given sample is the minimum is equal to its rate parameter divided by the sum of all rate parameters.

  2. It is "memoryless". So if you already know the minimum, then the probability that any of the remaining elements is the 2nd-to-min is the same as the probability that if the true min were removed (and never generated), that element would have been the new min. This seems obvious, but I think because of some conditional probability issues, it might not be true of other distributions.

Using fact 1, we know that choosing a single element can be done by generating these exponential distribution samples with rate parameter equal to the weight, and then choosing the one with minimum value.

Using fact 2, we know that we don't have to re-generate the exponential samples. Instead, just generate one for each element, and take the k elements with lowest samples.

Finding the lowest k can be done in O(n) expected time. Use the Quickselect algorithm to find the k-th smallest sample, then simply take another pass through all elements and output every element whose sample is less than or equal to the k-th.
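
For reference, a Quickselect-style selection of the k smallest values might look like the sketch below. This is one simple list-based variant, not necessarily the fastest; the function name and the optional rng parameter are my own additions for illustration. It runs in expected linear time and returns the k smallest values in no particular order.

```python
import random

def k_smallest(values, k, rng=random):
    """Return the k smallest entries of values (unordered), using a
    Quickselect-style partition around a random pivot."""
    values = list(values)
    if k <= 0:
        return []
    if k >= len(values):
        return values
    result = []                        # entries already known to be in the answer
    while True:
        pivot = rng.choice(values)
        lower = [x for x in values if x < pivot]
        equal = [x for x in values if x == pivot]
        higher = [x for x in values if x > pivot]
        need = k - len(result)         # how many more entries we still need
        if need < len(lower):
            values = lower             # answer lies strictly inside `lower`
        elif need <= len(lower) + len(equal):
            # all of `lower` plus enough copies of the pivot finish the job
            return result + lower + equal[:need - len(lower)]
        else:
            result += lower + equal    # keep everything up to the pivot
            values = higher            # and continue in the upper part

print(sorted(k_smallest([5, 1, 4, 2, 3], 2)))   # [1, 2]
```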

A useful note: if you don't have immediate access to a library to generate exponential distribution samples, it can easily be done with -ln(rand())/weight, where rand() is uniform on (0, 1] so that the logarithm is always defined.
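
Putting the pieces together, a minimal Python sketch could look like this. It assumes items is a list of (weight, value) pairs with strictly positive weights, as in the first answer; the function name and rng parameter are illustrative. Note that it uses heapq.nsmallest for the selection step, which is O(n log k) rather than O(n), simply to keep the example short. Swapping in a Quickselect would restore the O(n) bound.

```python
import heapq
import math
import random

def weighted_sample_without_replacement(items, k, rng=random):
    """items: list of (weight, value) pairs, weights strictly positive."""
    keyed = []
    for index, (weight, value) in enumerate(items):
        u = 1.0 - rng.random()         # in (0, 1], so log(u) is defined
        key = -math.log(u) / weight    # exponential sample with rate = weight
        # index breaks ties so the values themselves are never compared
        keyed.append((key, index, value))
    # the k smallest keys correspond to the k selected elements
    return [value for key, index, value in heapq.nsmallest(k, keyed)]

items = [(10, "low"), (100, "mid"), (890, "large")]
print(weighted_sample_without_replacement(items, 2))
```

Because heapq.nsmallest returns the keys in ascending order, the output order matches the order in which the elements would have been drawn one at a time.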


If the sampling is with replacement, use the roulette-wheel selection technique (often used in genetic algorithms):

  1. sort the weights
  2. compute the cumulative weights
  3. pick a random number in [0,1]*totalWeight
  4. find the interval into which this number falls
  5. select the element corresponding to that interval
  6. repeat steps 3-5 k times
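
The steps above can be sketched in Python roughly as follows. This assumes items is a list of (weight, value) pairs with positive weights; the function name and rng parameter are made up for the example. The sort in step 1 is left out here because the binary search over cumulative weights does not depend on the order of the items.

```python
import bisect
import itertools
import random

def roulette_wheel_sample(items, k, rng=random):
    """items: list of (weight, value) pairs, weights positive."""
    # step 2: cumulative weights, e.g. [10, 110, 1000]
    cumulative = list(itertools.accumulate(w for w, v in items))
    total = cumulative[-1]
    picks = []
    for _ in range(k):                               # step 6
        r = rng.random() * total                     # step 3
        i = bisect.bisect_right(cumulative, r)       # step 4
        i = min(i, len(items) - 1)                   # guard against round-off
        picks.append(items[i][1])                    # step 5
    return picks

items = [(10, "low"), (100, "mid"), (890, "large")]
print(roulette_wheel_sample(items, 5))
```

In Python 3.6 and later, the standard library's random.choices(population, weights=..., k=...) performs this kind of weighted sampling with replacement, so in practice you may not need to write it yourself.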

If the sampling is without replacement, you can adapt the above technique by removing the selected element from the list after each iteration, then re-normalizing the remaining weights so that their sum is 1 (a valid probability distribution).
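
A straightforward sketch of that adaptation is shown below, again assuming (weight, value) pairs; the function name and rng parameter are illustrative. Instead of dividing every weight by the new total, it scales the random number by the current total, which has the same effect. Each pick is a linear scan, so drawing k of m items costs O(k * m), which is fine for small lists.

```python
import random

def sample_by_removal(items, k, rng=random):
    """items: list of (weight, value) pairs, weights positive."""
    pool = list(items)                   # copy so the caller's list is untouched
    picks = []
    for _ in range(min(k, len(pool))):
        total = sum(w for w, v in pool)  # "re-normalize" by using the new total
        r = rng.random() * total
        for i, (w, v) in enumerate(pool):
            if r < w:                    # r falls inside this item's interval
                break
            r -= w
        # if round-off prevents a break, i is simply the last index
        picks.append(pool.pop(i)[1])     # remove so it cannot be picked again
    return picks

items = [(10, "low"), (100, "mid"), (890, "large")]
print(sample_by_removal(items, 2))
```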