Very Fast Reservoir Sampling by Erik Erlandson

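Today I formally added this sampling algorithm to Apache Flink. Erik Erlandson published it on his blog, and it is probably the fastest sampling algorithm to date. It uses the popular gap-distribution approach: rather than testing every element, it generates the gaps between successive samples and so performs approximately random sampling, which reduces memory usage while also reducing CPU usage.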

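In his blog post http://erikerlandson.github.io/blog/2015/08/17/the-reservoir-sampling-gap-distribution/ he shows that random sampling can be achieved by generating gaps. This greatly reduces both the time spent generating random numbers and the memory used, and in practice it can greatly improve the efficiency of the whole system. His blog covers two implementations, one based on the Bernoulli distribution (https://github.com/apache/flink/blob/master/flink-java/src/main/java/org/apache/flink/api/java/sampling/BernoulliSampler.java) and one based on the Poisson distribution (https://github.com/apache/flink/blob/master/flink-java/src/main/java/org/apache/flink/api/java/sampling/PoissonSampler.java). In Apache Flink, the Bernoulli version implements sampling without replacement, and the Poisson version implements sampling with replacement.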
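
To make the gap idea concrete, here is a minimal sketch of fixed-probability sampling without replacement that draws one random number per accepted element instead of one per input element. It is not the code of Flink's BernoulliSampler; the class name BernoulliGapSketch and its sample method are made up for illustration, and drawing u from (0, 1] to avoid log(0) is my own assumption.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/** Hypothetical illustration class; not part of Flink or of Erik Erlandson's code. */
final class BernoulliGapSketch {

    /**
     * Keeps each element independently with probability p (0 < p < 1),
     * but jumps over the rejected elements instead of testing each one.
     */
    static <T> List<T> sample(Iterator<T> input, double p, Random rng) {
        if (!(p > 0.0 && p < 1.0)) {
            throw new IllegalArgumentException("p must be in (0, 1)");
        }
        List<T> out = new ArrayList<>();
        double logOneMinusP = Math.log(1.0 - p);
        while (input.hasNext()) {
            // u lies in (0, 1], so Math.log(u) is never -Infinity (assumption of this sketch).
            double u = 1.0 - rng.nextDouble();
            // Geometric gap: number of elements rejected before the next accepted one.
            long gap = (long) Math.floor(Math.log(u) / logOneMinusP);
            long skipped = 0;
            while (skipped < gap && input.hasNext()) {
                input.next();
                skipped++;
            }
            // Accept the element that follows the gap, if the input has not run out.
            if (input.hasNext()) {
                out.add(input.next());
            }
        }
        return out;
    }
}
```

With p = 0.1, for example, the loop body runs roughly once per ten input elements, so the number of random draws scales with the sample size rather than with the data size.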

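This new, optimized algorithm builds on the geometric distribution (https://en.wikipedia.org/wiki/Geometric_distribution). When the sample size is far smaller than the data size (by a factor of 1,000,000 or more), the sampling probability is approximately P = R/j, where R is the size of the sample set and j is the number of elements seen so far. The algorithm consists of two parts.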

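  1. The author defines a threshold T = 4R, where R is the size of the sample set. The factor 4 is determined by the quality of the random number generator, and the way it is determined is quite involved. While the amount of data is below the threshold T, traditional reservoir sampling is used. The reasons for using reservoir sampling here are:
    1. When little data has been seen, generating gaps distorts the distribution of the sample considerably. The gaps in this algorithm only approach a uniform distribution cumulatively, so a small number of gaps produces a very large error. Experiments show that under the KS test, with a sample size of 100 and a data size of 100000, the error relative to random sampling far exceeds the KS test's tolerance D.
    2. With little data, generating a small number of random numbers does not put much load on the system, so this trade-off is acceptable.
    3. The distribution of reservoir sampling does not depend on the data size, so even when the sample size and the data size are close, for example R=100 and j=10000, it still samples perfectly at random.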
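  2. Once the data exceeds T, switch to sampling with the geometric distribution. The gap is computed as follows:
    1. First compute p, the probability that the current element is selected under ideal random sampling. Clearly, p = R/j.
    2. 1-p is the probability of not being selected. Generate u, a float in (0,1), and compute Math.log(u)/Math.log(1-p). This is inverse-transform sampling from a geometric distribution with success probability p: the integer part of the result is the number of consecutive elements that are not selected before the next selected one, and that number is the gap.
    3. Then skip gap elements and take the next one. This achieves random sampling.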
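
The two parts above can be sketched as follows. This is a simplified, self-contained illustration and not the Flink sampler: it keeps the reservoir in a plain list and overwrites a random slot, whereas the Flink code below stores weighted IntermediateSampleData in a PriorityQueue. The names GapReservoirSketch, nextGap and sample are made up for illustration, and drawing u from (0, 1] to avoid log(0) is my own assumption.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/** Hypothetical illustration class; simpler than the Flink sampler. */
final class GapReservoirSketch {

    /** Gap for a uniform draw u in (0, 1], reservoir size r, and j elements seen so far (j > r). */
    static long nextGap(int r, long j, double u) {
        double p = (double) r / j; // ideal selection probability p = R/j
        return (long) Math.floor(Math.log(u) / Math.log(1.0 - p));
    }

    static <T> List<T> sample(Iterator<T> input, int r, Random rng) {
        if (r <= 0) {
            throw new IllegalArgumentException("r must be positive");
        }
        List<T> reservoir = new ArrayList<>(r);
        long threshold = 4L * r; // T = 4R, as described in the text
        long j = 0;              // number of elements consumed so far

        // Part 1: traditional reservoir sampling while fewer than T elements were seen.
        while (j < threshold && input.hasNext()) {
            T element = input.next();
            if (j < r) {
                reservoir.add(element);
            } else {
                // Pick a slot in [0, j]; the element is kept with probability r / (j + 1).
                long slot = (long) (rng.nextDouble() * (j + 1));
                if (slot < r) {
                    reservoir.set((int) slot, element);
                }
            }
            j++;
        }

        // Part 2: skip `gap` elements, then put the next one into a random slot.
        while (input.hasNext()) {
            long gap = nextGap(r, j, 1.0 - rng.nextDouble());
            long skipped = 0;
            while (skipped < gap && input.hasNext()) {
                input.next();
                skipped++;
                j++;
            }
            if (input.hasNext()) {
                reservoir.set(rng.nextInt(r), input.next());
                j++;
            }
        }
        return reservoir;
    }
}
```

As a worked number: with R = 100, j = 1,000,000 and u = 0.5, p is 0.0001 and nextGap returns 6931, so about seven thousand elements are skipped with a single random draw.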

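This algorithm depends heavily on the relationship between the sample size and the data size. If the former is not far smaller than the latter, the gap sizes make the distribution extremely uneven, and this unevenness produces a large error in the sample.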
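
To get a feeling for the scale involved, the mean of a geometric gap with p = R/j is (1-p)/p = (j-R)/R. The small sketch below only prints that mean for a few data sizes; it does not quantify the sampling error itself. MeanGapSketch and meanGap are hypothetical names used for illustration.

```java
/** Hypothetical helper for illustration only. */
final class MeanGapSketch {

    /** Mean of a geometric gap with p = r / j, which is (1 - p) / p = (j - r) / r. */
    static double meanGap(int r, long j) {
        return (double) (j - r) / r;
    }

    public static void main(String[] args) {
        int r = 100;
        for (long j : new long[] {400L, 10_000L, 1_000_000L}) {
            System.out.println("R=" + r + ", j=" + j + ", mean gap=" + meanGap(r, j));
        }
    }
}
```

For R = 100 the mean gap is 3 at the threshold j = 400, 99 at j = 10000, and 9999 at j = 1000000, so the gaps only become long once the data is much larger than the sample.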

Here is the code:

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.api.java.sampling;

import com.google.common.base.Preconditions;
import org.apache.flink.annotation.Internal;
import org.apache.flink.util.XORShiftRandom;

import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * An in-memory implementation of Very Fast Reservoir Sampling. The algorithm works well when the size of the streaming data is much larger than the size of the reservoir.
 * The algorithm samples at random with probability p = R/j, where R is the size of the sample and j is the current index in the streaming data.
 * The algorithm consists of two parts:
 * 	(1) Before the size of the streaming data reaches the threshold, it uses regular reservoir sampling.
 * 	(2) After the size of the streaming data reaches the threshold, it uses a geometric distribution to generate an approximate gap
 * 		of elements to skip; the size of the gap is drawn from a geometric distribution with probability p = R/j.
 *
 * Thanks to Erik Erlandson, the author of this algorithm, who helped me with the implementation.
 * @param <T> The type of sample.
 * @see <a href="http://erikerlandson.github.io/blog/2015/11/20/very-fast-reservoir-sampling/">Very Fast Reservoir Sampling</a>
 */
@Internal
public class VeryFastReservoirSampler<T> extends DistributedRandomSampler<T> {

	private final Random random;
	// THRESHOLD is a tuning parameter: the number of elements after which sampling switches
	// from regular reservoir sampling to gap sampling.
	private final int THRESHOLD = 4 * super.numSamples;

	/**
	 * Create a new sampler with reservoir size and a supplied random number generator.
	 *
	 * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
	 * @param random     Instance of random number generator for sampling.
	 */
	public VeryFastReservoirSampler(int numSamples, Random random) {
		super(numSamples);
		Preconditions.checkArgument(numSamples >= 0, "numSamples should be non-negative.");
		this.random = random;
	}

	/**
	 * Create a new sampler with reservoir size and a default random number generator.
	 *
	 * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
	 */
	public VeryFastReservoirSampler(int numSamples) {
		this(numSamples, new XORShiftRandom());
	}

	/**
	 * Create a new sampler with reservoir size and the seed for random number generator.
	 *
	 * @param numSamples Maximum number of samples to retain in reservoir, must be non-negative.
	 * @param seed       Random number generator seed.
	 */
	public VeryFastReservoirSampler(int numSamples, long seed) {
		this(numSamples, new XORShiftRandom(seed));
	}

	@Override
	public Iterator<IntermediateSampleData<T>> sampleInPartition(Iterator<T> input) {
		if (numSamples == 0) {
			return EMPTY_INTERMEDIATE_ITERABLE;
		}
		PriorityQueue<IntermediateSampleData<T>> queue = new PriorityQueue<IntermediateSampleData<T>>(numSamples);
		double probability;
		IntermediateSampleData<T> smallest = null;
		int index = 0, gap = 0;
		while (input.hasNext()) {
			T element = input.next();
			if (index < THRESHOLD) {  // Below THRESHOLD: use regular reservoir sampling.
				if (index < numSamples) {
					// Fill the queue with the first numSamples elements from input.
					queue.add(new IntermediateSampleData<T>(random.nextDouble(), element));
					smallest = queue.peek();
				} else {
					double rand = random.nextDouble();
					// If the new weight beats the smallest weight, evict that element and add the current one.
					if (rand > smallest.getWeight()) {
						queue.remove();
						queue.add(new IntermediateSampleData<T>(rand, element));
						smallest = queue.peek();
					}
				}
				index++;
			} else {  // Fast section: draw a geometric gap and skip that many elements.
				double rand = random.nextDouble();
				probability = (double) numSamples / index;
				gap = (int) (Math.log(rand) / Math.log(1 - probability));
				int elementCount = 0;
				while (input.hasNext() && elementCount < gap) {
					elementCount++;
					element = input.next();
					index++;
				}
				if (elementCount < gap) {
					// The input ran out before the whole gap was skipped, so there is nothing left to sample.
					continue;
				} else {
					// Replace the element with the smallest weight by the element that follows the gap.
					queue.remove();
					queue.add(new IntermediateSampleData<T>(random.nextDouble(), element));
					index++;
				}
			}
		}
		return queue.iterator();
	}
}