Efficient Weighted Sampling

Here’s a really simple and cheap way to do importance sampling that I came across a few months ago (while learning about particle filters, incidentally). It’s simple enough that I have no idea how I went so long without ever knowing about it! In case anyone else is in the same boat, let me walk you through it.

Let’s say you have an array of N values and an associated array of N weights, and you want to randomly sample your input values such that the probability of each value being chosen is proportional to its weight.

There are many ways of doing that, but I’d like to focus on a particularly efficient approach (it’s O(N) with a relatively small constant factor).
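
For comparison, the textbook way to do this is to build a running sum of the weights and then binary-search it once per output sample. Here’s a rough sketch along those lines (the ResampleWithCdf name and signature are just made up for illustration):

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Sketch of the usual CDF-based approach, for comparison only: build a
// cumulative sum of the weights, then binary-search it for every output
// sample. That costs O(N + M log N) for M output samples.
std::vector<float> ResampleWithCdf(const std::vector<float>& values,
                                   const std::vector<float>& weights,
                                   int outN)
{
    // Running total of the weights (an unnormalised CDF).
    std::vector<float> cdf(weights.size());
    std::partial_sum(weights.begin(), weights.end(), cdf.begin());

    std::default_random_engine generator;
    std::uniform_real_distribution<float> rnd(0.0f, cdf.back());

    std::vector<float> outputs(outN);
    for (int i = 0; i < outN; ++i)
    {
        // Pick a point along the total weight, then find the first cell whose
        // cumulative weight exceeds it.
        float p = rnd(generator);
        size_t ix = std::upper_bound(cdf.begin(), cdf.end(), p) - cdf.begin();
        outputs[i] = values[std::min(ix, values.size() - 1)];
    }
    return outputs;
}

The approach I want to focus on gets rid of the per-sample binary search by walking the input weights and the output cells along in lockstep.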

Imagine laying out 8 input values in a row like this, where their weights correspond to their widths in the diagram.

Now, to select 8 weighted samples we could line up another row of evenly spaced “cells” below the first one, corresponding to our chosen samples (modulo my poor mspaint skills).


To pick our samples, we could just select a point in the center of each of the cells in the lower row and pick the input value it lines up with in the top row. Unfortunately, that would give you the same output samples every time you call it, which isn’t very random. Instead, let’s randomly pick a point within each sample’s cell and let the output sample be the input value that corresponds to that point in the upper row.


If you look at each of the stars in the lower row and line them up with the weighted cells above, you’ll see that we’d end up choosing the first and second cells once each, skipping the third cell, sampling the fourth cell twice, and so on. It’s pretty easy to see that if we were to do this a bunch of times with different random offsets for the cells each time, each input value would get selected a number of times proportional to its width (i.e. its weight).

This is called stratified sampling (in particular, it’s usually called “jittered sampling”), and chances are you’ve seen something like this before in other contexts. It’s not a pure weighted sample because it enforces a certain “spread” of the randomly selected points. For example, if the weights were all the same then each value would get sampled exactly once.

Here’s some code that implements this in C++:

#include <random>

void Resample(float* values, float* weights, float* outputs, int inputN, int outN)
{
    float sumWeights = 0.0f;
    for (int i = 0; i < inputN; ++i)
    {
        sumWeights += weights[i];
    }
    
    float sampleWidth = sumWeights / outN;
    
    // Note that upper end of this range is exclusive. That matters.
    std::default_random_engine generator;
    std::uniform_real_distribution<float> rnd(0, sampleWidth);

    int outputSampleIx = -1;
    float weightSoFar = 0.0f;
    for (int i = 0; i < outN; ++i)
    {
        // How far is this sample from the origin?
        float sampleDist = i*sampleWidth + rnd(generator);      

        // Find which value to output: walk along the input weights until the
        // running sum of the weights exceeds the current sample's distance.
        while (sampleDist >= weightSoFar) // BEWARE: there is a bug here, see below.
        {           
            weightSoFar += weights[++outputSampleIx];
        }           
        outputs[i] = values[outputSampleIx];
    }
}

First, we compute the total “weight” of our input samples, then we compute the width of the sample “cells” from that. Next, we loop through our output samples, and for each one we compute a random offset within its “cell” and then simply walk along the input weights until the running sum of weights exceeds that sample point’s distance. We don’t have to re-walk the input weights for each output sample (we just continue from the previous output index), which is what makes this O(N). Note that the number of output samples can differ from the number of input samples.

This all works in principle, but unfortunately floating point precision trips us up. Very occasionally the sampleDist for a sample point ends up being greater than the sum of all the weights, due to accumulated floating point rounding error, and that causes the inner while loop above to index out of bounds in the weights array. To fix this, we modify the while loop condition to make sure we never go beyond the bounds of the array:

while (sampleDist >= weightSoFar && outputSampleIx + 1 < inputN)

This avoids the crash, but it does mean our sampling strategy is ever so slightly biased towards the last value in the input array. A simple way to drastically reduce this effect is to use doubles for all the accumulation variables.
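
For example, a sketch of the first version with the accumulation done in doubles might look like this (the ResampleDouble name is just for illustration; the inputs and outputs stay float):

#include <random>

// Same algorithm as above, but the running sums (total weight, sample
// distance, accumulated weight) are doubles, so the rounding error that
// triggers the out-of-bounds case becomes far less likely.
void ResampleDouble(float* values, float* weights, float* outputs, int inputN, int outN)
{
    double sumWeights = 0.0;
    for (int i = 0; i < inputN; ++i)
    {
        sumWeights += weights[i];
    }

    double sampleWidth = sumWeights / outN;
    std::default_random_engine generator;
    std::uniform_real_distribution<double> rnd(0.0, sampleWidth);

    int outputSampleIx = -1;
    double weightSoFar = 0.0;
    for (int i = 0; i < outN; ++i)
    {
        double sampleDist = i * sampleWidth + rnd(generator);

        // Keep the bounds check: doubles shrink the error, they don't remove it.
        while (sampleDist >= weightSoFar && outputSampleIx + 1 < inputN)
        {
            weightSoFar += weights[++outputSampleIx];
        }
        outputs[i] = values[outputSampleIx];
    }
}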

Now, the above code works fine and is pretty efficient (especially if you already know the sum of the weights and can skip the initial pass over the input data to compute it). But maybe we can do better? If the ordering of the input values doesn’t have any particular structure, then we can avoid picking a random offset for each output sample and instead pick a single random offset that gets reused for all the sampling cells. This is called Stochastic Universal Sampling. Obviously this is even more efficient, since the most expensive operation in the previous version is likely to be the random number generation, and we’ve just moved that out of the loop. And since this offset is now constant, we can fold it into the weightSoFar accumulator by subtracting it once up front, rather than adding it to every sampleDist.

Here’s the final version of the code.

#include <random>

void Resample(float* input, float* weights, float* outputs, int inputN, int outN)
{
    float sumWeights = 0.0f;
    for (int i = 0; i < inputN; ++i)
    {
        sumWeights += weights[i];
    }
    
    float sampleWidth = sumWeights / outN;
    std::default_random_engine generator;
    std::uniform_real_distribution<float> rnd(0, sampleWidth);
    int outputSampleIx = -1;
    float weightSoFar = -rnd(generator);
    for (int i = 0; i < outN; ++i)
    {
        // How far is this sample from the origin? (The offset now lives in weightSoFar.)
        float sampleDist = i*sampleWidth;

        // Find which value to output: walk along the input weights until the
        // running sum of the weights exceeds the current sample's distance.
        while (sampleDist >= weightSoFar && outputSampleIx + 1 < inputN)
        {
            weightSoFar += weights[++outputSampleIx];
        }           
        outputs[i] = input[outputSampleIx]; 
    }
}
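
To give a feel for how you might call it, here’s a small made-up example. With a total weight of 6 and 6 output samples, each cell is exactly one unit wide, so the two weight-2 values each cover two cells:

#include <cstdio>

int main()
{
    float values[]  = { 1.0f, 2.0f, 3.0f, 4.0f };
    float weights[] = { 1.0f, 2.0f, 1.0f, 2.0f };
    float outputs[6];

    Resample(values, weights, outputs, 4, 6);

    // With these weights the cells line up with whole units, so this should
    // print "1 2 2 3 4 4" regardless of the random offset.
    for (float v : outputs)
    {
        printf("%g ", v);
    }
    printf("\n");
    return 0;
}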